Repository: DLR-RM/stable-baselines3
Branch: master
Commit: a72be407aa61
Files: 170
Total size: 1.3 MB
Directory structure:
gitextract_vxq3k1wk/
├── .github/
│ ├── ISSUE_TEMPLATE/
│ │ ├── bug_report.yml
│ │ ├── custom_env.yml
│ │ ├── documentation.yml
│ │ ├── feature_request.yml
│ │ └── question.yml
│ ├── PULL_REQUEST_TEMPLATE.md
│ └── workflows/
│   └── ci.yml
├── .gitignore
├── .readthedocs.yml
├── CITATION.bib
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── Dockerfile
├── LICENSE
├── Makefile
├── NOTICE
├── README.md
├── docs/
│ ├── Makefile
│ ├── README.md
│ ├── _static/
│ │ └── css/
│ │   └── baselines_theme.css
│ ├── common/
│ │ ├── atari_wrappers.md
│ │ ├── distributions.md
│ │ ├── env_checker.md
│ │ ├── env_util.md
│ │ ├── envs.md
│ │ ├── evaluation.md
│ │ ├── logger.md
│ │ ├── monitor.md
│ │ ├── noise.md
│ │ └── utils.md
│ ├── conda_env.yml
│ ├── conf.py
│ ├── guide/
│ │ ├── algos.md
│ │ ├── callbacks.md
│ │ ├── checking_nan.md
│ │ ├── custom_env.md
│ │ ├── custom_policy.md
│ │ ├── developer.md
│ │ ├── examples.md
│ │ ├── export.md
│ │ ├── imitation.md
│ │ ├── install.md
│ │ ├── integrations.md
│ │ ├── migration.md
│ │ ├── plotting.md
│ │ ├── quickstart.md
│ │ ├── rl.md
│ │ ├── rl_tips.md
│ │ ├── rl_zoo.md
│ │ ├── save_format.md
│ │ ├── sb3_contrib.md
│ │ ├── sbx.md
│ │ ├── tensorboard.md
│ │ └── vec_envs.md
│ ├── index.rst
│ ├── make.bat
│ ├── misc/
│ │ ├── changelog.md
│ │ └── projects.md
│ ├── modules/
│ │ ├── a2c.md
│ │ ├── base.md
│ │ ├── ddpg.md
│ │ ├── dqn.md
│ │ ├── her.md
│ │ ├── ppo.md
│ │ ├── sac.md
│ │ └── td3.md
│ └── spelling_wordlist.txt
├── pyproject.toml
├── scripts/
│ ├── build_docker.sh
│ ├── run_docker_cpu.sh
│ ├── run_docker_gpu.sh
│ └── run_tests.sh
├── setup.py
├── stable_baselines3/
│ ├── __init__.py
│ ├── a2c/
│ │ ├── __init__.py
│ │ ├── a2c.py
│ │ └── policies.py
│ ├── common/
│ │ ├── __init__.py
│ │ ├── atari_wrappers.py
│ │ ├── base_class.py
│ │ ├── buffers.py
│ │ ├── callbacks.py
│ │ ├── distributions.py
│ │ ├── env_checker.py
│ │ ├── env_util.py
│ │ ├── envs/
│ │ │ ├── __init__.py
│ │ │ ├── bit_flipping_env.py
│ │ │ ├── identity_env.py
│ │ │ └── multi_input_envs.py
│ │ ├── evaluation.py
│ │ ├── logger.py
│ │ ├── monitor.py
│ │ ├── noise.py
│ │ ├── off_policy_algorithm.py
│ │ ├── on_policy_algorithm.py
│ │ ├── policies.py
│ │ ├── preprocessing.py
│ │ ├── results_plotter.py
│ │ ├── running_mean_std.py
│ │ ├── save_util.py
│ │ ├── sb2_compat/
│ │ │ ├── __init__.py
│ │ │ └── rmsprop_tf_like.py
│ │ ├── torch_layers.py
│ │ ├── type_aliases.py
│ │ ├── utils.py
│ │ └── vec_env/
│ │   ├── __init__.py
│ │   ├── base_vec_env.py
│ │   ├── dummy_vec_env.py
│ │   ├── patch_gym.py
│ │   ├── stacked_observations.py
│ │   ├── subproc_vec_env.py
│ │   ├── util.py
│ │   ├── vec_check_nan.py
│ │   ├── vec_extract_dict_obs.py
│ │   ├── vec_frame_stack.py
│ │   ├── vec_monitor.py
│ │   ├── vec_normalize.py
│ │   ├── vec_transpose.py
│ │   └── vec_video_recorder.py
│ ├── ddpg/
│ │ ├── __init__.py
│ │ ├── ddpg.py
│ │ └── policies.py
│ ├── dqn/
│ │ ├── __init__.py
│ │ ├── dqn.py
│ │ └── policies.py
│ ├── her/
│ │ ├── __init__.py
│ │ ├── goal_selection_strategy.py
│ │ └── her_replay_buffer.py
│ ├── ppo/
│ │ ├── __init__.py
│ │ ├── policies.py
│ │ └── ppo.py
│ ├── py.typed
│ ├── sac/
│ │ ├── __init__.py
│ │ ├── policies.py
│ │ └── sac.py
│ ├── td3/
│ │ ├── __init__.py
│ │ ├── policies.py
│ │ └── td3.py
│ └── version.txt
└── tests/
  ├── __init__.py
  ├── test_buffers.py
  ├── test_callbacks.py
  ├── test_cnn.py
  ├── test_custom_policy.py
  ├── test_deterministic.py
  ├── test_dict_env.py
  ├── test_distributions.py
  ├── test_env_checker.py
  ├── test_envs.py
  ├── test_gae.py
  ├── test_her.py
  ├── test_identity.py
  ├── test_logger.py
  ├── test_monitor.py
  ├── test_n_step_replay.py
  ├── test_predict.py
  ├── test_preprocessing.py
  ├── test_run.py
  ├── test_save_load.py
  ├── test_sde.py
  ├── test_spaces.py
  ├── test_tensorboard.py
  ├── test_train_eval_mode.py
  ├── test_utils.py
  ├── test_vec_check_nan.py
  ├── test_vec_envs.py
  ├── test_vec_extract_dict_obs.py
  ├── test_vec_monitor.py
  ├── test_vec_normalize.py
  └── test_vec_stacked_obs.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/ISSUE_TEMPLATE/bug_report.yml
================================================
name: "\U0001F41B Bug Report"
description: If you encounter an unexpected behavior, software crash, or other bug.
title: "[Bug]: bug title"
labels: ["bug"]
body:
- type: markdown
attributes:
value: |
**Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email.
Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) in that case.
If your issue is related to a **custom gym environment**, please use the custom gym env template.
- type: textarea
id: description
attributes:
label: 🐛 Bug
description: A clear and concise description of what the bug is.
validations:
required: true
- type: textarea
id: reproduce
attributes:
label: To Reproduce
description: |
Steps to reproduce the behavior. Please try to provide a minimal example to reproduce the bug. Error messages and stack traces are also helpful.
Please use the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces.
value: |
```python
from stable_baselines3 import ...
```
- type: textarea
id: traceback
attributes:
label: Relevant log output / Error message
description: Please copy and paste any relevant log output / error message. This will be automatically formatted into code, so no need for backticks.
placeholder: "Traceback (most recent call last): File ..."
render: shell
- type: textarea
id: system-info
attributes:
label: System Info
description: |
Describe the characteristics of your environment:
* Describe how the library was installed (pip, docker, source, ...)
* GPU models and configuration
* Python version
* PyTorch version
* Gymnasium version
* (if installed) OpenAI Gym version
* Versions of any other relevant libraries
You can use `sb3.get_system_info()` to print relevant packages info:
```sh
python -c 'import stable_baselines3 as sb3; sb3.get_system_info()'
```
- type: checkboxes
id: terms
attributes:
label: Checklist
options:
- label: My issue does not relate to a custom gym environment. (Use the custom gym env template instead)
required: true
- label: I have checked that there is no similar [issue](https://github.com/DLR-RM/stable-baselines3/issues) in the repo
required: true
- label: I have read the [documentation](https://stable-baselines3.readthedocs.io/en/master/)
required: true
- label: I have provided a [minimal and working](https://github.com/DLR-RM/stable-baselines3/issues/982#issuecomment-1197044014) example to reproduce the bug
required: true
- label: I've used the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces.
required: true
================================================
FILE: .github/ISSUE_TEMPLATE/custom_env.yml
================================================
name: "\U0001F916 Custom Gym Environment Issue"
description: If your problem involves a custom gym environment.
labels: ["custom gym env"]
body:
- type: markdown
attributes:
value: |
**Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email.
Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) in that case.
**Please check your environment first using**:
```python
from stable_baselines3.common.env_checker import check_env
env = CustomEnv(arg1, ...)
# It will check your custom environment and output additional warnings if needed
check_env(env)
```
- type: textarea
id: description
attributes:
label: 🐛 Bug
description: A clear and concise description of what the bug is.
validations:
required: true
- type: textarea
id: code-example
attributes:
label: Code example
description: |
Please try to provide a [minimal example](https://github.com/DLR-RM/stable-baselines3/issues/982#issuecomment-1197044014) to reproduce the bug.
For a custom environment, you need to give at least the observation space, action space, `reset()` and `step()` methods (see working example below).
Error messages and stack traces are also helpful.
Please use the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces.
value: |
```python
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from stable_baselines3 import A2C
from stable_baselines3.common.env_checker import check_env
class CustomEnv(gym.Env):

    def __init__(self):
        super().__init__()
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(14,))
        self.action_space = spaces.Box(low=-1, high=1, shape=(6,))

    def reset(self, seed=None, options=None):
        return self.observation_space.sample(), {}

    def step(self, action):
        obs = self.observation_space.sample()
        reward = 1.0
        terminated = False
        truncated = False
        info = {}
        return obs, reward, terminated, truncated, info

env = CustomEnv()
check_env(env)
model = A2C("MlpPolicy", env, verbose=1).learn(1000)
```
- type: textarea
id: traceback
attributes:
label: Relevant log output / Error message
description: Please copy and paste any relevant log output / error message. This will be automatically formatted into code, so no need for backticks.
placeholder: "Traceback (most recent call last): File ..."
render: shell
- type: textarea
id: system-info
attributes:
label: System Info
description: |
Describe the characteristics of your environment:
* Describe how the library was installed (pip, docker, source, ...)
* GPU models and configuration
* Python version
* PyTorch version
* Gymnasium version
* (if installed) OpenAI Gym version
* Versions of any other relevant libraries
You can use `sb3.get_system_info()` to print relevant packages info:
```sh
python -c 'import stable_baselines3 as sb3; sb3.get_system_info()'
```
- type: checkboxes
id: terms
attributes:
label: Checklist
options:
- label: I have checked that there is no similar [issue](https://github.com/DLR-RM/stable-baselines3/issues) in the repo
required: true
- label: I have read the [documentation](https://stable-baselines3.readthedocs.io/en/master/)
required: true
- label: I have provided a [minimal and working](https://github.com/DLR-RM/stable-baselines3/issues/982#issuecomment-1197044014) example to reproduce the bug
required: true
- label: I have checked my env using the env checker
required: true
- label: I've used the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces.
required: true
================================================
FILE: .github/ISSUE_TEMPLATE/documentation.yml
================================================
name: "\U0001F4DA Documentation"
description: If you want to improve the documentation by reporting errors, inconsistencies, or missing information.
labels: ["documentation"]
body:
- type: markdown
attributes:
value: |
**Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email.
Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) in that case.
- type: textarea
id: description
attributes:
label: 📚 Documentation
description: A clear and concise description of what should be improved in the documentation.
validations:
required: true
- type: checkboxes
id: terms
attributes:
label: Checklist
options:
- label: I have checked that there is no similar [issue](https://github.com/DLR-RM/stable-baselines3/issues) in the repo
required: true
- label: I have read the [documentation](https://stable-baselines3.readthedocs.io/en/master/)
required: true
================================================
FILE: .github/ISSUE_TEMPLATE/feature_request.yml
================================================
name: "\U0001F680 Feature Request"
description: If you have an idea for a new feature or an improvement.
title: "[Feature Request] request title"
labels: ["enhancement"]
body:
- type: markdown
attributes:
value: |
**Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email.
Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) in that case.
- type: textarea
id: description
attributes:
label: 🚀 Feature
description: A clear and concise description of the feature proposal.
validations:
required: true
- type: textarea
id: motivation
attributes:
label: Motivation
description: Please outline the motivation for the proposal. Is your feature request related to a problem? e.g., "I'm always frustrated when [...]". If this is related to another GitHub issue, please link here too.
- type: textarea
id: pitch
attributes:
label: Pitch
description: A clear and concise description of what you want to happen.
- type: textarea
id: alternatives
attributes:
label: Alternatives
description: A clear and concise description of any alternative solutions or features you've considered, if any.
- type: textarea
id: additional-context
attributes:
label: Additional context
description: Add any other context or screenshots about the feature request here.
- type: checkboxes
id: terms
attributes:
label: Checklist
options:
- label: I have checked that there is no similar [issue](https://github.com/DLR-RM/stable-baselines3/issues) in the repo
required: true
- label: If I'm requesting a new feature, I have proposed alternatives
required: true
================================================
FILE: .github/ISSUE_TEMPLATE/question.yml
================================================
name: "❓ Question"
description: If you have a general question about Stable-Baselines3.
title: "[Question] question title"
labels: ["question"]
body:
- type: markdown
attributes:
value: |
**Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email.
Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) in that case.
- type: textarea
id: question
attributes:
label: ❓ Question
description: |
Your question. This can be e.g. questions regarding confusing or unclear behaviour of functions or a question if X can be done using stable-baselines3. Make sure to check out the documentation first.
**Important Note: If your question is anything like "Why is my code generating this error?", you must [submit a bug report](https://github.com/DLR-RM/stable-baselines3/issues/new?assignees=&labels=bug&projects=&template=bug_report.yml&title=%5BBug%5D%3A+bug+title) instead.**
validations:
required: true
- type: checkboxes
id: terms
attributes:
label: Checklist
options:
- label: I have checked that there is no similar [issue](https://github.com/DLR-RM/stable-baselines3/issues) in the repo
required: true
- label: I have read the [documentation](https://stable-baselines3.readthedocs.io/en/master/)
required: true
- label: If code there is, it is [minimal and working](https://github.com/DLR-RM/stable-baselines3/issues/982#issuecomment-1197044014)
required: true
- label: If code there is, it is formatted using the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces.
required: true
================================================
FILE: .github/PULL_REQUEST_TEMPLATE.md
================================================
<!--- Provide a general summary of your changes in the Title above -->
## Description
<!--- Describe your changes in detail -->
## Motivation and Context
<!--- Why is this change required? What problem does it solve? -->
<!--- If it fixes an open issue, please link to the issue here. -->
<!--- You can use the syntax `closes #100` if this solves the issue #100 -->
- [ ] I have raised an issue to propose this change ([required](https://github.com/DLR-RM/stable-baselines3/blob/master/CONTRIBUTING.md) for new features and bug fixes)
## Types of changes
<!--- What types of changes does your code introduce? Put an `x` in all the boxes that apply: -->
- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Documentation (update in the documentation)
## Checklist
<!--- Go over all the following points, and put an `x` in all the boxes that apply. -->
<!--- If you're unsure about any of these, don't hesitate to ask. We're here to help! -->
- [ ] I've read the [CONTRIBUTION](https://github.com/DLR-RM/stable-baselines3/blob/master/CONTRIBUTING.md) guide (**required**)
- [ ] I have updated the changelog accordingly (`docs/misc/changelog.md`) (**required**).
- [ ] My change requires a change to the documentation.
- [ ] I have updated the tests accordingly (*required for a bug fix or a new feature*).
- [ ] I have updated the documentation accordingly.
- [ ] I have opened an associated PR on the [SB3-Contrib repository](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) (if necessary)
- [ ] I have opened an associated PR on the [RL-Zoo3 repository](https://github.com/DLR-RM/rl-baselines3-zoo) (if necessary)
- [ ] I have reformatted the code using `make format` (**required**)
- [ ] I have checked the codestyle using `make check-codestyle` and `make lint` (**required**)
- [ ] I have ensured `make pytest` and `make type` both pass. (**required**)
- [ ] I have checked that the documentation builds using `make doc` (**required**)
Note: You can run most of the checks using `make commit-checks`.
Note: we are using a maximum length of 127 characters per line
<!--- This Template is an edited version of the one from https://github.com/evilsocket/pwnagotchi/ -->
================================================
FILE: .github/workflows/ci.yml
================================================
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: CI
on:
  push:
    branches: [master]
  pull_request:
    branches: [master]
jobs:
  build:
    env:
      TERM: xterm-256color
      FORCE_COLOR: 1
    # Skip CI if [ci skip] in the commit message
    if: "! contains(toJSON(github.event.commits.*.message), '[ci skip]')"
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.10", "3.11", "3.12", "3.13"]
        include:
          # Default version
          - gymnasium-version: "1.0.0"
          # Add a new config to test gym<1.0
          - python-version: "3.10"
            gymnasium-version: "0.29.1"
    steps:
      - uses: actions/checkout@v6
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v6
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          # Use uv for faster downloads
          pip install uv
          # cpu version of pytorch
          # See https://github.com/astral-sh/uv/issues/1497
          # Need Pytorch 2.9+ for Python 3.13
          uv pip install --system torch==2.9.1+cpu --index https://download.pytorch.org/whl/cpu
          uv pip install --system .[extra,tests,docs]
          # Use headless version
          uv pip install --system opencv-python-headless
      - name: Install specific version of gym
        run: |
          uv pip install --system gymnasium==${{ matrix.gymnasium-version }}
          uv pip install --system "numpy<2"
          uv pip install --system "ale-py==0.10.1"
        # Only run for python 3.10, downgrade gym to 0.29.1, numpy<2, ale-py==0.10.1
        if: matrix.gymnasium-version != '1.0.0'
      - name: Lint with ruff
        run: |
          make lint
      - name: Build the doc
        run: |
          make doc
      - name: Check codestyle
        run: |
          make check-codestyle
      - name: Type check
        run: |
          make type
      - name: Test with pytest
        run: |
          make pytest
================================================
FILE: .gitignore
================================================
*.swp
*.pyc
*.pkl
*.py~
*.bak
.pytest_cache
.mypy_cache
.DS_Store
.idea
.vscode
.coverage
.coverage.*
__pycache__/
_build/
*.npz
*.pth
.pytype/
git_rewrite_commit_history.sh
# Setuptools distribution and build folders.
/dist/
/build
keys/
# Virtualenv
/env
/venv
*.sublime-project
*.sublime-workspace
.idea
logs/
.ipynb_checkpoints
ghostdriver.log
htmlcov
junk
src
*.egg-info
.cache
*.lprof
*.prof
MUJOCO_LOG.TXT
================================================
FILE: .readthedocs.yml
================================================
# Read the Docs configuration file
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
# Required
version: 2
# Build documentation in the docs/ directory with Sphinx
sphinx:
  configuration: docs/conf.py
# Optionally build your docs in additional formats such as PDF and ePub
formats: all
# Set requirements using conda env
conda:
  environment: docs/conda_env.yml
build:
  os: ubuntu-24.04
  tools:
    python: "mambaforge-23.11"
================================================
FILE: CITATION.bib
================================================
@article{stable-baselines3,
author = {Antonin Raffin and Ashley Hill and Adam Gleave and Anssi Kanervisto and Maximilian Ernestus and Noah Dormann},
title = {Stable-Baselines3: Reliable Reinforcement Learning Implementations},
journal = {Journal of Machine Learning Research},
year = {2021},
volume = {22},
number = {268},
pages = {1-8},
url = {http://jmlr.org/papers/v22/20-1364.html}
}
================================================
FILE: CODE_OF_CONDUCT.md
================================================
# Contributor Covenant Code of Conduct
## Our Pledge
We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socioeconomic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.
## Our Standards
Examples of behavior that contributes to a positive environment for our
community include:
* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
and learning from the experience
* Focusing on what is best not just for us as individuals, but for the
overall community
Examples of unacceptable behavior include:
* The use of sexualized language or imagery, and sexual attention or
advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email
address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Enforcement Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.
Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.
## Scope
This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
antonin [dot] raffin [at] dlr [dot] de.
All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the
reporter of any incident.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:
### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.
**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.
### 2. Warning
**Community Impact**: A violation through a single incident or series
of actions.
**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or
permanent ban.
### 3. Temporary Ban
**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.
**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction within
the community.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.0, available at
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
Community Impact Guidelines were inspired by [Mozilla's code of conduct
enforcement ladder](https://github.com/mozilla/diversity).
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see the FAQ at
https://www.contributor-covenant.org/faq. Translations are available at
https://www.contributor-covenant.org/translations.
================================================
FILE: CONTRIBUTING.md
================================================
## Contributing to Stable-Baselines3
**Important: When submitting issues or pull requests, the use of LLM or code assistants (e.g., Claude or Copilot) must be publicly disclosed.**
If you are interested in contributing to Stable-Baselines3, your contributions will fall
into one of two categories:
1. You want to propose a new feature and implement it
    - Create an issue about your intended feature, and we shall discuss the design and
      implementation. Once we agree that the plan looks good, go ahead and implement it.
2. You want to implement a feature or bug-fix for an outstanding issue
    - Look at the outstanding issues here: https://github.com/DLR-RM/stable-baselines3/labels/help%20wanted
    - Pick an issue or feature and comment on it to say that you want to work on it.
    - If you need more context on a particular issue, please ask, and we shall provide.
Once you finish implementing a feature or bug-fix, please send a Pull Request to
https://github.com/DLR-RM/stable-baselines3
Note: If you do not follow the template (and its mandatory steps), your pull request will be ignored.
If you are not familiar with creating a Pull Request, here are some guides:
- http://stackoverflow.com/questions/14680711/how-to-do-a-github-pull-request
- https://help.github.com/articles/creating-a-pull-request/
## Developing Stable-Baselines3
To develop Stable-Baselines3 on your machine, here are some tips:
1. Clone a copy of Stable-Baselines3 from source:
```bash
git clone https://github.com/DLR-RM/stable-baselines3
cd stable-baselines3/
```
2. Install Stable-Baselines3 in develop mode, with support for building the docs and running tests:
```bash
pip install -e '.[docs,tests,extra]'
```
## Codestyle
We use [black codestyle](https://github.com/psf/black) (max line length of 127 characters) together with [ruff](https://github.com/astral-sh/ruff) (isort rules) to sort the imports.
For the documentation, we use the default line length of 88 characters per line.
**Please run `make format`** to reformat your code. You can check the codestyle using `make check-codestyle` and `make lint`.
Please document each function/method and [type](https://mypy-lang.org/) them using the following template:
```python
def my_function(arg1: type1, arg2: type2) -> returntype:
    """
    Short description of the function.

    :param arg1: describe what is arg1
    :param arg2: describe what is arg2
    :return: describe what is returned
    """
    ...
    return my_variable
```
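As an illustration only, the template above could be filled in as follows. The function `discounted_return` is a hypothetical helper invented for this sketch and is not part of Stable-Baselines3; it simply shows the type hints and the `:param:` / `:return:` docstring fields in use.

```python
def discounted_return(rewards: list[float], gamma: float = 0.99) -> float:
    """
    Compute the discounted sum of rewards for a single episode.

    :param rewards: rewards collected at each step, in chronological order
    :param gamma: discount factor, between 0 and 1
    :return: the discounted return of the episode
    """
    total = 0.0
    # Accumulate from the last step backwards: G_t = r_t + gamma * G_{t+1}
    for reward in reversed(rewards):
        total = reward + gamma * total
    return total
```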
## Pull Request (PR)
**Important: We do not accept PRs that are fully generated using an LLM/code assistant unless triggered by a maintainer. Use of code assistants (e.g., Claude, Copilot) must be publicly disclosed.**
Before proposing a PR, please open an issue where the feature will be discussed. This prevents duplicate PRs from being proposed and also eases the code review process.
Each PR needs to be reviewed and accepted by at least one of the maintainers (@hill-a, @araffin, @ernestum, @AdamGleave, @Miffyli or @qgallouedec).
A PR must pass the Continuous Integration tests to be merged with the master branch.
## Tests
All new features must add tests in the `tests/` folder ensuring that everything works fine.
We use [pytest](https://pytest.org/).
Also, when a bug fix is proposed, tests should be added to avoid regression.
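A test file for a new feature might look like the sketch below. Everything in it is hypothetical: the file name `tests/test_my_feature.py` and the toy function `linear_decay` are assumptions made for illustration, not existing Stable-Baselines3 code. It relies only on pytest's convention of collecting functions whose names start with `test_`.

```python
# Hypothetical contents of tests/test_my_feature.py (names are illustrative only)


def linear_decay(initial_value: float, progress_remaining: float) -> float:
    """
    Toy function under test: scale a value by the remaining training progress.

    :param initial_value: value at the start of training
    :param progress_remaining: fraction of training left, from 1 (start) to 0 (end)
    :return: the scaled value
    """
    return initial_value * progress_remaining


def test_linear_decay_endpoints() -> None:
    # Full value at the start of training, zero at the end
    assert linear_decay(1e-3, 1.0) == 1e-3
    assert linear_decay(1e-3, 0.0) == 0.0


def test_linear_decay_regression() -> None:
    # Regression-style check that pins down a previously fixed behavior
    assert linear_decay(2.0, 0.5) == 1.0
```

In a real contribution, the function under test would be imported from the package rather than defined in the test file.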
To run tests with `pytest`:
```
make pytest
```
Type checking with `mypy`:
```
make type
```
Codestyle check with `black`, and `ruff` (`isort` rules):
```
make check-codestyle
make lint
```
To run `type`, `format` and `lint` in one command:
```
make commit-checks
```
Build the documentation:
```
make doc
```
Check documentation spelling (you need to install `sphinxcontrib.spelling` package for that):
```
make spelling
```
## Changelog and Documentation
Please do not forget to update the changelog (`docs/misc/changelog.md`) and add documentation if needed.
You should add your username next to each changelog entry that you added. If this is your first contribution, please add your username at the bottom too.
A README is present in the `docs/` folder for instructions on how to build the documentation.
Credits: this contributing guide is based on the [PyTorch](https://github.com/pytorch/pytorch/) one.
================================================
FILE: Dockerfile
================================================
ARG PARENT_IMAGE=mambaorg/micromamba:2.0-ubuntu24.04
FROM $PARENT_IMAGE
ARG PYTORCH_DEPS=https://download.pytorch.org/whl/cpu
ARG PYTHON_VERSION=3.12
ARG MAMBA_DOCKERFILE_ACTIVATE=1 # (otherwise python will not be found)
# Install micromamba env and dependencies
RUN micromamba install -n base -y python=$PYTHON_VERSION && \
micromamba clean --all --yes
ENV CODE_DIR=/home/$MAMBA_USER
# Copy setup file only to install dependencies
COPY --chown=$MAMBA_USER:$MAMBA_USER ./setup.py ${CODE_DIR}/stable-baselines3/setup.py
COPY --chown=$MAMBA_USER:$MAMBA_USER ./stable_baselines3/version.txt ${CODE_DIR}/stable-baselines3/stable_baselines3/version.txt
RUN cd ${CODE_DIR}/stable-baselines3 && \
pip install uv && \
uv pip install --system torch --default-index ${PYTORCH_DEPS} && \
uv pip install --system -e .[extra,tests,docs] && \
# Use headless version for docker
uv pip uninstall opencv-python && \
uv pip install --system opencv-python-headless && \
pip cache purge && \
uv cache clean
CMD /bin/bash
================================================
FILE: LICENSE
================================================
The MIT License
Copyright (c) 2019 Antonin Raffin
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
================================================
FILE: Makefile
================================================
SHELL=/bin/bash
LINT_PATHS=stable_baselines3/ tests/ docs/conf.py setup.py
pytest:
	./scripts/run_tests.sh
mypy:
	mypy ${LINT_PATHS}
missing-annotations:
	mypy --disallow-untyped-calls --disallow-untyped-defs --ignore-missing-imports stable_baselines3
# missing docstrings
# pylint -d R,C,W,E -e C0116 stable_baselines3 -j 4
type: mypy
lint:
	# stop the build if there are Python syntax errors or undefined names
	# see https://www.flake8rules.com/
	ruff check ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full
	# exit-zero treats all errors as warnings.
	ruff check ${LINT_PATHS} --exit-zero --output-format=concise
format:
	# Sort imports
	ruff check --select I ${LINT_PATHS} --fix
	# Reformat using black
	black ${LINT_PATHS}
check-codestyle:
	# Check import order (no files are modified)
	ruff check --select I ${LINT_PATHS}
	# Check formatting with black (no files are modified)
	black --check ${LINT_PATHS}
commit-checks: format type lint
doc:
	cd docs && make html
spelling:
	cd docs && make spelling
clean:
	cd docs && make clean
# Build docker images
# If you do export RELEASE=True, it will also push them
docker: docker-cpu docker-gpu
docker-cpu:
	./scripts/build_docker.sh
docker-gpu:
	USE_GPU=True ./scripts/build_docker.sh
# PyPi package release
release:
	python -m build
	twine upload dist/*
# Test PyPi package release
test-release:
	python -m build
	twine upload --repository-url https://test.pypi.org/legacy/ dist/*
.PHONY: clean spelling doc lint format check-codestyle commit-checks
================================================
FILE: NOTICE
================================================
Large portions of the code of Stable-Baselines3 (in `common/`) were ported from Stable-Baselines, a fork of OpenAI Baselines,
both licensed under the MIT License:
before the fork (June 2018):
Copyright (c) 2017 OpenAI (http://openai.com)
after the fork (June 2018):
Copyright (c) 2018-2019 Stable-Baselines Team
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
================================================
FILE: README.md
================================================
# Stable Baselines3
<img src="docs/\_static/img/logo.png" align="right" width="40%"/>
Stable Baselines3 (SB3) is a set of reliable implementations of reinforcement learning algorithms in PyTorch. It is the next major version of [Stable Baselines](https://github.com/hill-a/stable-baselines).
You can read a detailed presentation of Stable Baselines3 in the [v1.0 blog post](https://araffin.github.io/post/sb3/) or our [JMLR paper](https://jmlr.org/papers/volume22/20-1364/20-1364.pdf).
These algorithms will make it easier for the research community and industry to replicate, refine, and identify new ideas, and will create good baselines to build projects on top of. We expect these tools will be used as a base around which new ideas can be added, and as a tool for comparing a new approach against existing ones. We also hope that the simplicity of these tools will allow beginners to experiment with a more advanced toolset, without being buried in implementation details.
**Note: Despite its simplicity of use, Stable Baselines3 (SB3) assumes you have some knowledge about Reinforcement Learning (RL).** You should not use this library without some practice. To that end, we provide good resources in the [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/rl.html) to get started with RL.
## Main Features
**The performance of each algorithm was tested** (see the *Results* section on their respective pages).
You can take a look at the issues [#48](https://github.com/DLR-RM/stable-baselines3/issues/48) and [#49](https://github.com/DLR-RM/stable-baselines3/issues/49) for more details.
We also provide detailed logs and reports on the [OpenRL Benchmark](https://wandb.ai/openrlbenchmark/sb3) platform.
| **Features** | **Stable-Baselines3** |
| --------------------------- | ----------------------|
| State of the art RL methods | :heavy_check_mark: |
| Documentation | :heavy_check_mark: |
| Custom environments | :heavy_check_mark: |
| Custom policies | :heavy_check_mark: |
| Common interface | :heavy_check_mark: |
| `Dict` observation space support | :heavy_check_mark: |
| Ipython / Notebook friendly | :heavy_check_mark: |
| Tensorboard support | :heavy_check_mark: |
| PEP8 code style | :heavy_check_mark: |
| Custom callback | :heavy_check_mark: |
| High code coverage | :heavy_check_mark: |
| Type hints | :heavy_check_mark: |
### Planned features
Since most of the features from the [original roadmap](https://github.com/DLR-RM/stable-baselines3/issues/1) have been implemented, there are no major changes planned for SB3; it is now *stable*.
If you want to contribute, you can search in the issues for the ones where [help is welcomed](https://github.com/DLR-RM/stable-baselines3/labels/help%20wanted) and the other [proposed enhancements](https://github.com/DLR-RM/stable-baselines3/labels/enhancement).
While SB3 development is now focused on bug fixes and maintenance (doc update, user experience, ...), there is more active development going on in the associated repositories:
- newer algorithms are regularly added to the [SB3 Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) repository
- faster variants are developed in the [SBX (SB3 + Jax)](https://github.com/araffin/sbx) repository
- the training framework for SB3, the RL Zoo, has an active [roadmap](https://github.com/DLR-RM/rl-baselines3-zoo/issues/299)
## Migration guide: from Stable-Baselines (SB2) to Stable-Baselines3 (SB3)
A migration guide from SB2 to SB3 can be found in the [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/migration.html).
## Documentation
Documentation is available online: [https://stable-baselines3.readthedocs.io/](https://stable-baselines3.readthedocs.io/)
## Integrations
Stable-Baselines3 has some integration with other libraries/services like Weights & Biases for experiment tracking or Hugging Face for storing/sharing trained models. You can find out more in the [dedicated section](https://stable-baselines3.readthedocs.io/en/master/guide/integrations.html) of the documentation.
## RL Baselines3 Zoo: A Training Framework for Stable Baselines3 Reinforcement Learning Agents
[RL Baselines3 Zoo](https://github.com/DLR-RM/rl-baselines3-zoo) is a training framework for Reinforcement Learning (RL).
It provides scripts for training, evaluating agents, tuning hyperparameters, plotting results and recording videos.
In addition, it includes a collection of tuned hyperparameters for common environments and RL algorithms, and agents trained with those settings.
Goals of this repository:
1. Provide a simple interface to train and enjoy RL agents
2. Benchmark the different Reinforcement Learning algorithms
3. Provide tuned hyperparameters for each environment and RL algorithm
4. Have fun with the trained agents!
Github repo: https://github.com/DLR-RM/rl-baselines3-zoo
Documentation: https://rl-baselines3-zoo.readthedocs.io/en/master/
## SB3-Contrib: Experimental RL Features
We implement experimental features in a separate contrib repository: [SB3-Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib)
This allows SB3 to maintain a stable and compact core, while still providing the latest features, like Recurrent PPO (PPO LSTM), CrossQ, Truncated Quantile Critics (TQC), Quantile Regression DQN (QR-DQN) or PPO with invalid action masking (Maskable PPO).
Documentation is available online: [https://sb3-contrib.readthedocs.io/](https://sb3-contrib.readthedocs.io/)
## Stable-Baselines Jax (SBX)
[Stable Baselines Jax (SBX)](https://github.com/araffin/sbx) is a proof of concept version of Stable-Baselines3 in Jax, with recent algorithms like DroQ or CrossQ.
It provides a minimal number of features compared to SB3 but can be much faster (up to 20x!): https://twitter.com/araffin2/status/1590714558628253698
## Installation
**Note:** Stable-Baselines3 supports PyTorch >= 2.3
### Prerequisites
Stable Baselines3 requires Python 3.10+.
#### Windows
To install Stable Baselines3 on Windows, please look at the [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/install.html#prerequisites).
### Install using pip
Install the Stable Baselines3 package:
```sh
pip install 'stable-baselines3[extra]'
```
This includes optional dependencies like Tensorboard, OpenCV or `ale-py` to train on Atari games. If you do not need those, you can use:
```sh
pip install stable-baselines3
```
Please read the [documentation](https://stable-baselines3.readthedocs.io/) for more details and alternatives (from source, using docker).
## Example
Most of the code in the library tries to follow a sklearn-like syntax for the Reinforcement Learning algorithms.
Here is a quick example of how to train and run PPO on a cartpole environment:
```python
import gymnasium as gym
from stable_baselines3 import PPO
env = gym.make("CartPole-v1", render_mode="human")
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000)
vec_env = model.get_env()
obs = vec_env.reset()
for i in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, info = vec_env.step(action)
    vec_env.render()
    # VecEnv resets automatically
    # if done:
    #   obs = env.reset()
env.close()
```
Or just train a model with a one liner if [the environment is registered in Gymnasium](https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/#registering-envs) and if [the policy is registered](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html):
```python
from stable_baselines3 import PPO
model = PPO("MlpPolicy", "CartPole-v1").learn(10_000)
```
Please read the [documentation](https://stable-baselines3.readthedocs.io/) for more examples.
## Try it online with Colab Notebooks !
All the following examples can be executed online using Google Colab notebooks:
- [Full Tutorial](https://github.com/araffin/rl-tutorial-jnrr19)
- [All Notebooks](https://github.com/Stable-Baselines-Team/rl-colab-notebooks/tree/sb3)
- [Getting Started](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/stable_baselines_getting_started.ipynb)
- [Training, Saving, Loading](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/saving_loading_dqn.ipynb)
- [Multiprocessing](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/multiprocessing_rl.ipynb)
- [Monitor Training and Plotting](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/monitor_training.ipynb)
- [Atari Games](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/atari_games.ipynb)
- [RL Baselines Zoo](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/rl-baselines-zoo.ipynb)
- [PyBullet](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/pybullet.ipynb)
## Implemented Algorithms
| **Name** | **Recurrent** | `Box` | `Discrete` | `MultiDiscrete` | `MultiBinary` | **Multi Processing** |
| ------------------- | ------------------ | ------------------ | ------------------ | ------------------- | ------------------ | --------------------------------- |
| ARS<sup>[1](#f1)</sup> | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
| A2C | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| CrossQ<sup>[1](#f1)</sup> | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
| DDPG | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
| DQN | :x: | :x: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
| HER | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
| PPO | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| QR-DQN<sup>[1](#f1)</sup> | :x: | :x: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
| RecurrentPPO<sup>[1](#f1)</sup> | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| SAC | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
| TD3 | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
| TQC<sup>[1](#f1)</sup> | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
| TRPO<sup>[1](#f1)</sup> | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| Maskable PPO<sup>[1](#f1)</sup> | :x: | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
<b id="f1">1</b>: Implemented in [SB3 Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) GitHub repository.
Actions `gymnasium.spaces`:
* `Box`: An N-dimensional box that contains every point in the action space.
* `Discrete`: A list of possible actions, where at each timestep only one of the actions can be used.
* `MultiDiscrete`: A list of possible actions, where at each timestep only one action of each discrete set can be used.
* `MultiBinary`: A list of possible actions, where at each timestep any of the actions can be used in any combination.
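To make the four action-space types above more concrete, the following sketch draws one sample of each kind using only the Python standard library. It deliberately does not use the `gymnasium.spaces` API; the function names (`sample_box`, `sample_discrete`, ...) are hypothetical and only illustrate the shape of a single action.

```python
import random

rng = random.Random(0)  # seeded for reproducibility


def sample_box(low, high):
    """One real-valued action per dimension, each within [low_i, high_i]."""
    return [rng.uniform(lo, hi) for lo, hi in zip(low, high)]


def sample_discrete(n):
    """Exactly one action index out of n possible actions."""
    return rng.randrange(n)


def sample_multi_discrete(nvec):
    """One action index per discrete set; nvec gives the size of each set."""
    return [rng.randrange(n) for n in nvec]


def sample_multi_binary(n):
    """n independent on/off switches, usable in any combination."""
    return [rng.randint(0, 1) for _ in range(n)]


box_action = sample_box(low=[-1.0, -1.0], high=[1.0, 1.0])
discrete_action = sample_discrete(4)
multi_discrete_action = sample_multi_discrete([3, 2, 5])
multi_binary_action = sample_multi_binary(6)
print(box_action, discrete_action, multi_discrete_action, multi_binary_action)
```

In practice you would define these spaces with `gymnasium.spaces` in your environment; the sketch only shows what an agent has to output for each type.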
## Testing the installation
### Install dependencies
```sh
pip install -e '.[docs,tests,extra]'
```
### Run tests
All unit tests in Stable Baselines3 can be run using the `pytest` runner:
```sh
make pytest
```
To run a single test file:
```sh
python3 -m pytest -v tests/test_env_checker.py
```
To run a single test:
```sh
python3 -m pytest -v -k 'test_check_env_dict_action'
```
You can also do a static type check using `mypy`:
```sh
pip install mypy
make type
```
Codestyle check with `ruff`:
```sh
pip install ruff
make lint
```
## Projects Using Stable-Baselines3
We try to maintain a list of projects using stable-baselines3 in the [documentation](https://stable-baselines3.readthedocs.io/en/master/misc/projects.html),
please tell us if you want your project to appear on this page ;)
## Citing the Project
To cite this repository in publications:
```bibtex
@article{stable-baselines3,
author = {Antonin Raffin and Ashley Hill and Adam Gleave and Anssi Kanervisto and Maximilian Ernestus and Noah Dormann},
title = {Stable-Baselines3: Reliable Reinforcement Learning Implementations},
journal = {Journal of Machine Learning Research},
year = {2021},
volume = {22},
number = {268},
pages = {1-8},
url = {http://jmlr.org/papers/v22/20-1364.html}
}
```
Note: If you need to refer to a specific version of SB3, you can also use the [Zenodo DOI](https://doi.org/10.5281/zenodo.8123988).
## Maintainers
Stable-Baselines3 is currently maintained by [Ashley Hill](https://github.com/hill-a) (aka @hill-a), [Antonin Raffin](https://araffin.github.io/) (aka [@araffin](https://github.com/araffin)), [Maximilian Ernestus](https://github.com/ernestum) (aka @ernestum), [Adam Gleave](https://github.com/adamgleave) (@AdamGleave), [Anssi Kanervisto](https://github.com/Miffyli) (@Miffyli) and [Quentin Gallouédec](https://gallouedec.com/) (@qgallouedec).
**Important Note: We do not provide technical support, or consulting** and do not answer personal questions via email.
Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/), or [Stack Overflow](https://stackoverflow.com/) in that case.
## How To Contribute
To anyone interested in making the baselines better: there is still some documentation that needs to be done.
If you want to contribute, please read the [**CONTRIBUTING.md**](./CONTRIBUTING.md) guide first.
## Acknowledgments
The initial work to develop Stable Baselines3 was partially funded by the project *Reduced Complexity Models* from the *Helmholtz-Gemeinschaft Deutscher Forschungszentren*, and by the EU's Horizon 2020 Research and Innovation Programme under grant number 951992 ([VeriDream](https://www.veridream.eu/)).
The original version, Stable Baselines, was created in the [robotics lab U2IS](http://u2is.ensta-paristech.fr/index.php?lang=en) ([INRIA Flowers](https://flowers.inria.fr/) team) at [ENSTA ParisTech](http://www.ensta-paristech.fr/en).
Logo credits: [L.M. Tenkes](https://www.instagram.com/lucillehue/)
================================================
FILE: docs/Makefile
================================================
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line.
# For debug: SPHINXOPTS = -nWT --keep-going -vvv
SPHINXOPTS = -W # make warnings fatal
SPHINXBUILD = sphinx-build
SPHINXPROJ = StableBaselines
SOURCEDIR = .
BUILDDIR = _build
# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
================================================
FILE: docs/README.md
================================================
# Stable Baselines3 Documentation
This folder contains documentation for the RL baselines.
### Build the Documentation
#### Install Sphinx and Theme
Execute this command in the project root:
```
pip install -e ".[docs]"
```
#### Building the Docs
In the `docs/` folder:
```
make html
```
If you want to rebuild the docs each time a file is changed:
```
sphinx-autobuild . _build/html
```
================================================
FILE: docs/_static/css/baselines_theme.css
================================================
/* Main colors adapted from pytorch doc */
:root{
--main-bg-color: #343A40;
--link-color: #FD7E14;
}
/* Header fonts */
h1, h2, .rst-content .toctree-wrapper p.caption, h3, h4, h5, h6, legend, p.caption {
font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif;
}
/* Docs background */
.wy-side-nav-search{
background-color: var(--main-bg-color);
}
/* Mobile version */
.wy-nav-top{
background-color: var(--main-bg-color);
}
/* Change link colors (except for the menu) */
a {
color: var(--link-color);
}
a:hover {
color: #4F778F;
}
.wy-menu a {
color: #b3b3b3;
}
.wy-menu a:hover {
color: #b3b3b3;
}
a.icon.icon-home {
color: #b3b3b3;
}
.version{
color: var(--link-color) !important;
}
/* Make code blocks have a background */
.codeblock,pre.literal-block,.rst-content .literal-block,.rst-content pre.literal-block,div[class^='highlight'] {
background: #f8f8f8;
}
/* Change style of types in the docstrings .rst-content .field-list */
.field-list .xref.py.docutils, .field-list code.docutils, .field-list .docutils.literal.notranslate
{
border: None;
padding-left: 0;
padding-right: 0;
color: #404040;
}
================================================
FILE: docs/common/atari_wrappers.md
================================================
(atari-wrapper)=
# Atari Wrappers
```{eval-rst}
.. automodule:: stable_baselines3.common.atari_wrappers
:members:
```
================================================
FILE: docs/common/distributions.md
================================================
(distributions)=
# Probability Distributions
Probability distributions used for the different action spaces:
- `CategoricalDistribution` -> Discrete
- `DiagGaussianDistribution` -> Box (continuous actions)
- `StateDependentNoiseDistribution` -> Box (continuous actions) when `use_sde=True`
% - ``MultiCategoricalDistribution`` -> MultiDiscrete
% - ``BernoulliDistribution`` -> MultiBinary
The policy networks output parameters for the distributions (named `flat` in the methods).
Actions are then sampled from those distributions.
For instance, in the case of discrete actions, the policy network outputs the probability
of taking each action. The `CategoricalDistribution` allows sampling from it,
computing the entropy and the log probability (`log_prob`), and backpropagating the gradient.
In the case of continuous actions, a Gaussian distribution is used. The policy network outputs
mean and (log) std of the distribution (assumed to be a `DiagGaussianDistribution`).
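
The discrete case described above can be sketched in plain Python. This is not the `CategoricalDistribution` implementation: it assumes the network emits unnormalized scores (logits) that are turned into probabilities with a softmax, it omits gradient computation entirely, and all function names are hypothetical.

```python
import math
import random


def softmax(logits):
    """Turn unnormalized scores into probabilities (numerically stable)."""
    max_logit = max(logits)
    exps = [math.exp(x - max_logit) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_action(probs, rng):
    """Draw one action index according to its probability."""
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


def log_prob(probs, action):
    """Log probability of the chosen action."""
    return math.log(probs[action])


def entropy(probs):
    """Shannon entropy of the action distribution (in nats)."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


rng = random.Random(0)
probs = softmax([2.0, 0.5, -1.0])  # stand-in for the policy network output
action = sample_action(probs, rng)
print(action, log_prob(probs, action), entropy(probs))
```

A uniform distribution gives the maximum entropy `log(n)`, which is why a more negative `entropy_loss` during training generally indicates a more exploratory policy.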
```{eval-rst}
.. automodule:: stable_baselines3.common.distributions
:members:
```
================================================
FILE: docs/common/env_checker.md
================================================
(env-checker)=
# Gym Environment Checker
```{eval-rst}
.. automodule:: stable_baselines3.common.env_checker
:members:
```
================================================
FILE: docs/common/env_util.md
================================================
(env-util)=
# Environments Utils
```{eval-rst}
.. automodule:: stable_baselines3.common.env_util
:members:
```
================================================
FILE: docs/common/envs.md
================================================
(envs)=
```{eval-rst}
.. automodule:: stable_baselines3.common.envs
```
# Custom Environments
Those environments were created for testing purposes.
## BitFlippingEnv
```{eval-rst}
.. autoclass:: BitFlippingEnv
:members:
```
## SimpleMultiObsEnv
```{eval-rst}
.. autoclass:: SimpleMultiObsEnv
:members:
```
================================================
FILE: docs/common/evaluation.md
================================================
(eval)=
# Evaluation Helper
```{eval-rst}
.. automodule:: stable_baselines3.common.evaluation
:members:
```
================================================
FILE: docs/common/logger.md
================================================
(logger)=
# Logger
To overwrite the default logger, you can pass one to the algorithm.
Available formats are `["stdout", "csv", "log", "tensorboard", "json"]`.
:::{warning}
When passing a custom logger object,
this will overwrite `tensorboard_log` and `verbose` settings
passed to the constructor.
:::
```python
from stable_baselines3 import A2C
from stable_baselines3.common.logger import configure
tmp_path = "/tmp/sb3_log/"
# set up logger
new_logger = configure(tmp_path, ["stdout", "csv", "tensorboard"])
model = A2C("MlpPolicy", "CartPole-v1", verbose=1)
# Set new logger
model.set_logger(new_logger)
model.learn(10000)
```
## Explanation of logger output
You can find below short explanations of the values logged in Stable-Baselines3 (SB3).
Depending on the algorithm used and on the wrappers/callbacks applied, SB3 only logs a subset of those keys during training.
Below you can find an example of the logger output when training a PPO agent:
```bash
-----------------------------------------
| eval/ | |
| mean_ep_length | 200 |
| mean_reward | -157 |
| rollout/ | |
| ep_len_mean | 200 |
| ep_rew_mean | -227 |
| time/ | |
| fps | 972 |
| iterations | 19 |
| time_elapsed | 80 |
| total_timesteps | 77824 |
| train/ | |
| approx_kl | 0.037781604 |
| clip_fraction | 0.243 |
| clip_range | 0.2 |
| entropy_loss | -1.06 |
| explained_variance | 0.999 |
| learning_rate | 0.001 |
| loss | 0.245 |
| n_updates | 180 |
| policy_gradient_loss | -0.00398 |
| std | 0.205 |
| value_loss | 0.226 |
-----------------------------------------
```
### eval/
All `eval/` values are computed by the `EvalCallback`.
- `mean_ep_length`: Mean episode length
- `mean_reward`: Mean episodic reward (during evaluation)
- `success_rate`: Mean success rate during evaluation (1.0 means 100% success), the environment info dict must contain an `is_success` key to compute that value
### rollout/
- `ep_len_mean`: Mean episode length (averaged over `stats_window_size` episodes, 100 by default)
- `ep_rew_mean`: Mean episodic training reward (averaged over `stats_window_size` episodes, 100 by default), a `Monitor` wrapper is required to compute that value (automatically added by `make_vec_env`).
- `exploration_rate`: Current value of the exploration rate when using DQN, it corresponds to the fraction of actions taken randomly (epsilon of the "epsilon-greedy" exploration)
- `success_rate`: Mean success rate during training (averaged over `stats_window_size` episodes, 100 by default), you must pass an extra argument to the `Monitor` wrapper to log that value (`info_keywords=("is_success",)`) and provide `info["is_success"]=True/False` on the final step of the episode
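
The windowed averaging behind `ep_len_mean` and `ep_rew_mean` can be sketched with a bounded `deque`. This is an illustrative stand-in, not the SB3 code: the class `RollingEpisodeStats` and its methods are hypothetical, and only the `stats_window_size` name and its default of 100 come from the descriptions above.

```python
from collections import deque


class RollingEpisodeStats:
    """Keep the last ``stats_window_size`` episodes and average over them."""

    def __init__(self, stats_window_size=100):
        # Once the deque is full, appending drops the oldest episode
        self.returns = deque(maxlen=stats_window_size)
        self.lengths = deque(maxlen=stats_window_size)

    def add_episode(self, episode_return, episode_length):
        self.returns.append(episode_return)
        self.lengths.append(episode_length)

    def ep_rew_mean(self):
        return sum(self.returns) / len(self.returns)

    def ep_len_mean(self):
        return sum(self.lengths) / len(self.lengths)


stats = RollingEpisodeStats(stats_window_size=3)
for episode_return, episode_length in [(1.0, 10), (2.0, 20), (3.0, 30), (4.0, 40)]:
    stats.add_episode(episode_return, episode_length)
# Only the three most recent episodes contribute to the means
print(stats.ep_rew_mean(), stats.ep_len_mean())
```

Because only recent episodes are kept, these values react to changes in the policy but lag behind by up to one window.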
### time/
- `episodes`: Total number of episodes
- `fps`: Number of frames per second (includes time taken by gradient update)
- `iterations`: Number of iterations (data collection + policy update for A2C/PPO)
- `time_elapsed`: Time in seconds since the beginning of training
- `total_timesteps`: Total number of timesteps (steps in the environments)
### train/
- `actor_loss`: Current value for the actor loss for off-policy algorithms
- `approx_kl`: Approximate mean KL divergence between old and new policy (for PPO), it is an estimate of how much the policy changed during the update
- `clip_fraction`: mean fraction of surrogate loss that was clipped (above `clip_range` threshold) for PPO.
- `clip_range`: Current value of the clipping factor for the surrogate loss of PPO
- `critic_loss`: Current value for the critic function loss for off-policy algorithms, usually error between value function output and TD(0), temporal difference estimate
- `ent_coef`: Current value of the entropy coefficient (when using SAC)
- `ent_coef_loss`: Current value of the entropy coefficient loss (when using SAC)
- `entropy_loss`: Mean value of the entropy loss (negative of the average policy entropy)
- `explained_variance`: Fraction of the return variance explained by the value function, see <https://scikit-learn.org/stable/modules/model_evaluation.html#explained-variance-score>
(ev=0 => might as well have predicted zero, ev=1 => perfect prediction, ev\<0 => worse than just predicting zero)
- `learning_rate`: Current learning rate value
- `loss`: Current total loss value
- `n_updates`: Number of gradient updates applied so far
- `policy_gradient_loss`: Current value of the policy gradient loss (its value does not have much meaning)
- `value_loss`: Current value for the value function loss for on-policy algorithms, usually error between value function output and Monte-Carlo estimate (or TD(lambda) estimate)
- `std`: Current standard deviation of the noise when using generalized State-Dependent Exploration (gSDE)
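
A few of the `train/` quantities above can be written out as small formulas. The sketch below uses plain Python rather than the library code. `explained_variance` follows the definition `1 - Var[y_true - y_pred] / Var[y_true]` implied by the description above; the `(exp(r) - 1) - r` estimator used for `approx_kl` is one common low-variance choice and is an assumption here, as are all function names.

```python
import math
from statistics import pvariance


def explained_variance(y_pred, y_true):
    """1 - Var[y_true - y_pred] / Var[y_true]; NaN if y_true is constant."""
    var_y = pvariance(y_true)
    if var_y == 0:
        return math.nan
    residuals = [t - p for t, p in zip(y_true, y_pred)]
    return 1.0 - pvariance(residuals) / var_y


def approx_kl(new_log_probs, old_log_probs):
    """Mean of (exp(r) - 1) - r, with r = new_log_prob - old_log_prob."""
    log_ratios = [n - o for n, o in zip(new_log_probs, old_log_probs)]
    return sum((math.exp(r) - 1.0) - r for r in log_ratios) / len(log_ratios)


def clip_fraction(ratios, clip_range):
    """Fraction of probability ratios further than clip_range from 1."""
    return sum(abs(r - 1.0) > clip_range for r in ratios) / len(ratios)


returns = [1.0, 2.0, 3.0, 4.0]
print(explained_variance(returns, returns))               # perfect prediction
print(explained_variance([0.0] * 4, returns))             # always predicting zero
print(explained_variance([4.0, 3.0, 2.0, 1.0], returns))  # worse than predicting zero
print(approx_kl([-1.0, -2.0], [-1.0, -2.0]))              # policy unchanged
print(clip_fraction([1.0, 1.3, 0.7, 1.1], clip_range=0.2))
```

The three `explained_variance` calls correspond to the `ev=1`, `ev=0` and `ev<0` cases mentioned in the list.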
```{eval-rst}
.. automodule:: stable_baselines3.common.logger
:members:
```
================================================
FILE: docs/common/monitor.md
================================================
(monitor)=
# Monitor Wrapper
```{eval-rst}
.. automodule:: stable_baselines3.common.monitor
:members:
```
================================================
FILE: docs/common/noise.md
================================================
(noise)=
# Action Noise
```{eval-rst}
.. automodule:: stable_baselines3.common.noise
:members:
```
================================================
FILE: docs/common/utils.md
================================================
(utils)=
# Utils
```{eval-rst}
.. automodule:: stable_baselines3.common.utils
:members:
```
================================================
FILE: docs/conda_env.yml
================================================
name: root
channels:
- pytorch
- conda-forge
dependencies:
- cpuonly=1.0=0
- pip=24.2
- python=3.11
- pytorch=2.5.0=py3.11_cpu_0
- pip:
    - gymnasium>=0.29.1,<1.1.0
    - cloudpickle
    - opencv-python-headless
    - pandas
    - numpy>=1.20,<3.0
    - matplotlib
    - sphinx>=5,<10
    - sphinx_rtd_theme>=3.0
    - sphinx_copybutton
    - myst-parser>=4,<6
================================================
FILE: docs/conf.py
================================================
#
# Configuration file for the Sphinx documentation builder.
#
# This file does only contain a selection of the most common options. For a
# full list see the documentation:
# http://www.sphinx-doc.org/en/master/config
# -- Path setup --------------------------------------------------------------
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import datetime
import os
import sys
# We CANNOT enable 'sphinxcontrib.spelling' because ReadTheDocs.org does not support
# PyEnchant.
try:
    import sphinxcontrib.spelling  # noqa: F401
    enable_spell_check = True
except ImportError:
    enable_spell_check = False
# Try to enable copy button
try:
    import sphinx_copybutton  # noqa: F401
    enable_copy_button = True
except ImportError:
    enable_copy_button = False
# source code directory, relative to this file, for sphinx-autobuild
sys.path.insert(0, os.path.abspath(".."))
# Read version from file
version_file = os.path.join(os.path.dirname(__file__), "../stable_baselines3", "version.txt")
with open(version_file) as file_handler:
__version__ = file_handler.read().strip()
# -- Project information -----------------------------------------------------
project = "Stable Baselines3"
copyright = f"2021-{datetime.date.today().year}, Stable Baselines3"
author = "Stable Baselines3 Contributors"
# The short X.Y version
version = "master (" + __version__ + " )"
# The full version, including alpha/beta/rc tags
release = __version__
# -- General configuration ---------------------------------------------------
# If your documentation needs a minimal Sphinx version, state it here.
#
# needs_sphinx = '1.0'
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
"sphinx.ext.mathjax",
"sphinx.ext.ifconfig",
"sphinx.ext.viewcode",
# 'sphinx.ext.intersphinx',
# 'sphinx.ext.doctest'
"myst_parser",
]
autodoc_typehints = "description"
if enable_spell_check:
    extensions.append("sphinxcontrib.spelling")
if enable_copy_button:
    extensions.append("sphinx_copybutton")
# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
source_suffix = [".rst", ".md"]
# source_suffix = ".rst"
# The master toctree document.
master_doc = "index"
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = "en"
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path .
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "README.md"]
# The name of the Pygments (syntax highlighting) style to use.
pygments_style = "sphinx"
# -- Options for HTML output -------------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
html_theme = "sphinx_rtd_theme"
html_logo = "_static/img/logo.png"
def setup(app):
    app.add_css_file("css/baselines_theme.css")
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
#
# html_theme_options = {}
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ["_static"]
# Custom sidebar templates, must be a dictionary that maps document names
# to template names.
#
# The default sidebars (for documents that don't match any pattern) are
# defined by theme itself. Builtin themes are using these templates by
# default: ``['localtoc.html', 'relations.html', 'sourcelink.html',
# 'searchbox.html']``.
#
# html_sidebars = {}
# -- Options for HTMLHelp output ---------------------------------------------
# Output file base name for HTML help builder.
htmlhelp_basename = "StableBaselines3doc"
# -- Options for LaTeX output ------------------------------------------------
latex_elements: dict[str, str] = {
# The paper size ('letterpaper' or 'a4paper').
#
# 'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
#
# 'pointsize': '10pt',
# Additional stuff for the LaTeX preamble.
#
# 'preamble': '',
# Latex figure (float) alignment
#
# 'figure_align': 'htbp',
}
# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
(master_doc, "StableBaselines3.tex", "Stable Baselines3 Documentation", "Stable Baselines3 Contributors", "manual"),
]
# -- Options for manual page output ------------------------------------------
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [(master_doc, "stablebaselines3", "Stable Baselines3 Documentation", [author], 1)]
# -- Options for Texinfo output ----------------------------------------------
# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
(
master_doc,
"StableBaselines3",
"Stable Baselines3 Documentation",
author,
"StableBaselines3",
"One line description of project.",
"Miscellaneous",
),
]
# -- Extension configuration -------------------------------------------------
myst_heading_anchors = 4
# See: https://myst-parser.readthedocs.io/en/latest/syntax/optional.html
myst_enable_extensions = [
# "amsmath",
"attrs_inline",
"colon_fence",
"deflist",
"dollarmath",
"fieldlist",
# "html_admonition",
"html_image",
# "linkify",
# "replacements",
# "smartquotes",
# "strikethrough",
"substitution",
# "tasklist",
]
# Example configuration for intersphinx: refer to the Python standard library.
# intersphinx_mapping = {
# 'python': ('https://docs.python.org/3/', None),
# 'numpy': ('http://docs.scipy.org/doc/numpy/', None),
# 'torch': ('http://pytorch.org/docs/master/', None),
# }
================================================
FILE: docs/guide/algos.md
================================================
# RL Algorithms
This table displays the RL algorithms that are implemented in the Stable Baselines3 project,
along with some useful characteristics: support for discrete/continuous actions, multiprocessing.
| Name | `Box` | `Discrete` | `MultiDiscrete` | `MultiBinary` | Multi Processing |
| ------------------ | ----- | ---------- | --------------- | ------------- | ---------------- |
| ARS [^f1] | ✔️ | ✔️ | ❌ | ❌ | ✔️ |
| A2C | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
| CrossQ [^f1] | ✔️ | ❌ | ❌ | ❌ | ✔️ |
| DDPG | ✔️ | ❌ | ❌ | ❌ | ✔️ |
| DQN | ❌ | ✔️ | ❌ | ❌ | ✔️ |
| HER | ✔️ | ✔️ | ❌ | ❌ | ✔️ |
| PPO | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
| QR-DQN [^f1] | ❌ | ✔️ | ❌ | ❌ | ✔️ |
| RecurrentPPO [^f1] | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
| SAC | ✔️ | ❌ | ❌ | ❌ | ✔️ |
| TD3 | ✔️ | ❌ | ❌ | ❌ | ✔️ |
| TQC [^f1] | ✔️ | ❌ | ❌ | ❌ | ✔️ |
| TRPO [^f1] | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
| Maskable PPO [^f1] | ❌ | ✔️ | ✔️ | ✔️ | ✔️ |
[^f1]: Implemented in [SB3 Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib)
:::{note}
`Tuple` observation spaces are not supported by any algorithm;
however, single-level `Dict` spaces are (cf. {ref}`Examples <examples>`).
:::
Actions `gym.spaces`:
- `Box`: An N-dimensional box that contains every point in the action
space.
- `Discrete`: A list of possible actions, where only one of the actions
can be used at each timestep.
- `MultiDiscrete`: A list of discrete action sets, where one action from each set can be used at each timestep.
- `MultiBinary`: A list of possible actions, where any combination of the actions can be used at each timestep.
:::{note}
More algorithms (like QR-DQN or TQC) are implemented in our [contrib repo](sb3_contrib.md)
and in our {ref}`SBX (SB3 + Jax) repo <sbx>` (DroQ, CrossQ, SimBa, ...).
:::
:::{note}
Some logging values (like `ep_rew_mean`, `ep_len_mean`) are only available when using a `Monitor` wrapper.
See [Issue #339](https://github.com/hill-a/stable-baselines/issues/339) for more info.
:::
:::{note}
When using off-policy algorithms, [Time Limits](https://arxiv.org/abs/1712.00378) (aka timeouts) are handled
properly (cf. [issue #284](https://github.com/DLR-RM/stable-baselines3/issues/284)).
You can revert to SB3 < 2.1.0 behavior by passing `handle_timeout_termination=False`
via the `replay_buffer_kwargs` argument.
:::
## Reproducibility
Completely reproducible results are not guaranteed across PyTorch releases or different platforms.
Furthermore, results need not be reproducible between CPU and GPU executions, even when using identical seeds.
To make computations deterministic for your specific problem on one specific platform,
you need to pass a `seed` argument when creating the model.
If you pass an environment to the model using `set_env()`, then you also need to seed the environment first.
Credit: part of the *Reproducibility* section comes from [PyTorch Documentation](https://pytorch.org/docs/stable/notes/randomness.html)
## Training exceeds `total_timesteps`
When you train an agent using SB3, you pass a `total_timesteps` parameter to the `learn()` method which defines the training budget for the agent (how many interactions with the environment are allowed).
For example:
```python
from stable_baselines3 import PPO
model = PPO("MlpPolicy", "CartPole-v1").learn(total_timesteps=1_000)
```
Because of the way the algorithms work, `total_timesteps` is a lower bound (see [issue #1150](https://github.com/DLR-RM/stable-baselines3/issues/1150)).
In the example above, PPO will effectively collect `n_steps * n_envs = 2048 * 1` steps despite `total_timesteps=1_000`.
In more detail:
- PPO/A2C and derivatives collect `n_steps * n_envs` steps of experience
before performing an update, so if you want to have exactly
`total_timesteps`, you will need to adjust those values
- SAC/DQN/TD3 and other off-policy algorithms collect
`train_freq * n_envs` steps before doing an update (when `train_freq` is in steps and not episodes), so if you want to have exactly `total_timesteps`
you have to adjust these values (`train_freq=4` by default for DQN)
- ARS and other population-based algorithms evaluate the policy for
`n_episodes` with `n_envs`, so unless the number of steps per
episode is fixed, it is not possible to exactly achieve
`total_timesteps`
- when using multiple envs, each call to `env.step()` corresponds to
`n_envs` timesteps, so it is no longer possible to use the
`EvalCallback` at an exact timestep
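
As a rough illustration of the rounding described above, the number of steps actually collected can be estimated by rounding `total_timesteps` up to the next multiple of the per-update batch. This is only a sketch: the helper `effective_timesteps` is hypothetical (it is not part of SB3), and it assumes that no callback interrupts training and that `train_freq` is expressed in steps.

```python
import math


def effective_timesteps(total_timesteps: int, steps_per_update: int, n_envs: int = 1) -> int:
    """Estimate how many environment steps are collected in total.

    ``steps_per_update`` stands for ``n_steps`` (PPO/A2C) or ``train_freq``
    in steps (SAC/DQN/TD3). Hypothetical helper, not part of SB3.
    """
    batch = steps_per_update * n_envs
    # Experience is collected in whole batches, so round up to the next multiple
    return math.ceil(total_timesteps / batch) * batch


# PPO defaults (n_steps=2048, one env): a 1_000-step budget becomes one full rollout
ppo_steps = effective_timesteps(1_000, steps_per_update=2048)
# DQN default (train_freq=4, one env): 1_000 is already a multiple of 4
dqn_steps = effective_timesteps(1_000, steps_per_update=4)
print(ppo_steps, dqn_steps)
```

To hit a budget exactly under these assumptions, choose `n_steps` (or `train_freq`) and `n_envs` so that their product divides `total_timesteps`.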
================================================
FILE: docs/guide/callbacks.md
================================================
(callbacks)=
# Callbacks
A callback is a set of functions that will be called at given stages of the training procedure.
You can use callbacks to access internal state of the RL model during training.
Callbacks make it possible to implement monitoring, auto-saving, model manipulation, progress bars, and more.
## Custom Callback
To build a custom callback, you need to create a class that derives from `BaseCallback`.
This will give you access to events (`_on_training_start`, `_on_step`) and useful variables (like `self.model` for the RL model).
You can find two examples of custom callbacks in the documentation: one for saving the best model according to the training reward (see {ref}`Examples <examples>`), and one for logging additional values with Tensorboard (see {ref}`Tensorboard section <tensorboard>`).
```python
from stable_baselines3.common.callbacks import BaseCallback


class CustomCallback(BaseCallback):
    """
    A custom callback that derives from ``BaseCallback``.

    :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
    """

    def __init__(self, verbose: int = 0):
        super().__init__(verbose)
        # Those variables will be accessible in the callback
        # (they are defined in the base class)
        # The RL model
        # self.model = None  # type: BaseAlgorithm
        # An alias for self.model.get_env(), the environment used for training
        # self.training_env  # type: VecEnv
        # Number of times the callback was called
        # self.n_calls = 0  # type: int
        # num_timesteps = n_envs * n times env.step() was called
        # self.num_timesteps = 0  # type: int
        # local and global variables
        # self.locals = {}  # type: Dict[str, Any]
        # self.globals = {}  # type: Dict[str, Any]
        # The logger object, used to report things in the terminal
        # self.logger  # type: stable_baselines3.common.logger.Logger
        # Sometimes, for event callback, it is useful
        # to have access to the parent object
        # self.parent = None  # type: Optional[BaseCallback]

    def _on_training_start(self) -> None:
        """
        This method is called before the first rollout starts.
        """
        pass

    def _on_rollout_start(self) -> None:
        """
        A rollout is the collection of environment interaction
        using the current policy.
        This event is triggered before collecting new samples.
        """
        pass

    def _on_step(self) -> bool:
        """
        This method will be called by the model after each call to `env.step()`.

        For child callback (of an `EventCallback`), this will be called
        when the event is triggered.

        :return: If the callback returns False, training is aborted early.
        """
        return True

    def _on_rollout_end(self) -> None:
        """
        This event is triggered before updating the policy.
        """
        pass

    def _on_training_end(self) -> None:
        """
        This event is triggered before exiting the `learn()` method.
        """
        pass
```
:::{note}
`self.num_timesteps` corresponds to the total number of steps taken in the environment, i.e., it is the number of environments multiplied by the number of times `env.step()` was called.
In other words, `self.num_timesteps` is incremented by `n_envs` (number of environments) after each call to `env.step()`.
:::
:::{note}
For off-policy algorithms like SAC, DDPG, TD3 or DQN, the notion of `rollout` corresponds to the steps taken in the environment between two updates.
:::
(eventcallback)=
## Event Callback
Compared to Keras, Stable Baselines provides a second type of `BaseCallback`, named `EventCallback`, which is meant to trigger events. When an event is triggered, a child callback is called.
As an example, {ref}`EvalCallback` is an `EventCallback` that will trigger its child callback when there is a new best model.
A child callback is for instance {ref}`StopTrainingOnRewardThreshold <StopTrainingCallback>` that stops the training if the mean reward achieved by the RL model is above a threshold.
:::{note}
We recommend taking a look at the source code of {ref}`EvalCallback` and {ref}`StopTrainingOnRewardThreshold <StopTrainingCallback>` to have a better overview of what can be achieved with this kind of callbacks.
:::
```python
class EventCallback(BaseCallback):
    """
    Base class for triggering callback on event.

    :param callback: Callback that will be called when an event is triggered.
    :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
    """

    def __init__(self, callback: BaseCallback, verbose: int = 0):
        super().__init__(verbose=verbose)
        self.callback = callback
        # Give access to the parent
        self.callback.parent = self

    ...

    def _on_event(self) -> bool:
        # Forward the event to the child callback through its public `on_step()` method
        return self.callback.on_step()
```
## Callback Collection
Stable Baselines provides you with a set of common callbacks for:
- saving the model periodically ({ref}`CheckpointCallback`)
- evaluating the model periodically and saving the best one ({ref}`EvalCallback`)
- chaining callbacks ({ref}`CallbackList`)
- triggering callback on events ({ref}`EventCallback`, {ref}`EveryNTimesteps`)
- logging data every N timesteps ({ref}`LogEveryNTimesteps`)
- stopping the training early based on a reward threshold ({ref}`StopTrainingOnRewardThreshold <StopTrainingCallback>`)
(checkpointcallback)=
### CheckpointCallback
Callback for saving a model every `save_freq` calls to `env.step()`. You must specify a log folder (`save_path`)
and can optionally set a prefix for the checkpoints (`rl_model` by default).
If you are using this callback to stop and resume training, you may want to optionally save the replay buffer if the
model has one (`save_replay_buffer`, `False` by default).
Additionally, if your environment uses a [VecNormalize](vec_envs.md#vecnormalize) wrapper, you can save the
corresponding statistics using `save_vecnormalize` (`False` by default).
:::{warning}
When using multiple environments, each call to `env.step()` will effectively correspond to `n_envs` steps.
If you want the `save_freq` to be similar when using a different number of environments,
you need to account for it using `save_freq = max(save_freq // n_envs, 1)`.
The same goes for the other callbacks.
:::
```python
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import CheckpointCallback
# Save a checkpoint every 1000 steps
checkpoint_callback = CheckpointCallback(
save_freq=1000,
save_path="./logs/",
name_prefix="rl_model",
save_replay_buffer=True,
save_vecnormalize=True,
)
model = SAC("MlpPolicy", "Pendulum-v1")
model.learn(2000, callback=checkpoint_callback)
```
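
The `save_freq` adjustment from the warning above can be sketched as a small helper. The name `per_env_freq` is hypothetical and not part of SB3; it only wraps the `max(save_freq // n_envs, 1)` expression so that the same conversion can be reused for other frequency arguments such as `eval_freq`.

```python
def per_env_freq(freq: int, n_envs: int) -> int:
    """Convert a frequency given in total timesteps into calls to ``env.step()``.

    Hypothetical helper, not part of SB3.
    """
    # Each call to env.step() advances n_envs timesteps; never go below one call
    return max(freq // n_envs, 1)


single_env = per_env_freq(1000, n_envs=1)
four_envs = per_env_freq(1000, n_envs=4)
many_envs = per_env_freq(10, n_envs=16)
print(single_env, four_envs, many_envs)
```

The result would then be passed as `save_freq` when building the `CheckpointCallback` for a vectorized environment.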
(evalcallback)=
### EvalCallback
Periodically evaluate the performance of an agent, using a separate test environment.
It will save the best model if the `best_model_save_path` folder is specified and save the evaluation results in a NumPy archive (`evaluations.npz`) if the `log_path` folder is specified.
:::{note}
You can pass child callbacks via `callback_after_eval` and `callback_on_new_best` arguments. `callback_after_eval` will be triggered after every evaluation, and `callback_on_new_best` will be triggered each time there is a new best model.
:::
:::{warning}
You need to make sure that `eval_env` is wrapped the same way as the training environment, for instance using the `VecTransposeImage` wrapper if you have a channel-last image as input.
The `EvalCallback` class outputs a warning if it is not the case.
:::
```python
import gymnasium as gym
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import EvalCallback
# Separate evaluation env
eval_env = gym.make("Pendulum-v1")
# Use deterministic actions for evaluation
eval_callback = EvalCallback(eval_env, best_model_save_path="./logs/",
log_path="./logs/", eval_freq=500,
deterministic=True, render=False)
model = SAC("MlpPolicy", "Pendulum-v1")
model.learn(5000, callback=eval_callback)
```
(progressbarcallback)=
### ProgressBarCallback
Display a progress bar with the current progress, elapsed time and estimated remaining time.
This callback is integrated inside SB3 via the `progress_bar` argument of the `learn()` method.
:::{note}
`ProgressBarCallback` requires the `tqdm` and `rich` packages to be installed. This is done automatically when using `pip install stable-baselines3[extra]`.
:::
```python
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import ProgressBarCallback
model = PPO("MlpPolicy", "Pendulum-v1")
# Display progress bar using the progress bar callback
# this is equivalent to model.learn(100_000, callback=ProgressBarCallback())
model.learn(100_000, progress_bar=True)
```
(callbacklist)=
### CallbackList
Class for chaining callbacks: they will be called sequentially.
Alternatively, you can pass a list of callbacks directly to the `learn()` method; it will be converted automatically to a `CallbackList`.
```python
import gymnasium as gym
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback, EvalCallback
checkpoint_callback = CheckpointCallback(save_freq=1000, save_path="./logs/")
# Separate evaluation env
eval_env = gym.make("Pendulum-v1")
eval_callback = EvalCallback(eval_env, best_model_save_path="./logs/best_model",
log_path="./logs/results", eval_freq=500)
# Create the callback list
callback = CallbackList([checkpoint_callback, eval_callback])
model = SAC("MlpPolicy", "Pendulum-v1")
# Equivalent to:
# model.learn(5000, callback=[checkpoint_callback, eval_callback])
model.learn(5000, callback=callback)
```
(stoptrainingcallback)=
### StopTrainingOnRewardThreshold
Stop the training once a threshold in episodic reward (mean episode reward over the evaluations) has been reached (i.e., when the model is good enough).
It must be used with the {ref}`EvalCallback` and use the event triggered by a new best model.
```python
import gymnasium as gym
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold
# Separate evaluation env
eval_env = gym.make("Pendulum-v1")
# Stop training when the model reaches the reward threshold
callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-200, verbose=1)
eval_callback = EvalCallback(eval_env, callback_on_new_best=callback_on_best, verbose=1)
model = SAC("MlpPolicy", "Pendulum-v1", verbose=1)
# Almost infinite number of timesteps, but the training will stop
# early as soon as the reward threshold is reached
model.learn(int(1e10), callback=eval_callback)
```
(everyntimesteps)=
### EveryNTimesteps
An {ref}`EventCallback` that will trigger its child callback every `n_steps` timesteps.
:::{note}
Because of the way `VecEnv` works, `n_steps` is a lower bound between two events when using multiple environments.
:::
```python
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback, EveryNTimesteps
# this is equivalent to defining CheckpointCallback(save_freq=500)
# checkpoint_callback will be triggered every 500 steps
checkpoint_on_event = CheckpointCallback(save_freq=1, save_path="./logs/")
event_callback = EveryNTimesteps(n_steps=500, callback=checkpoint_on_event)
model = PPO("MlpPolicy", "Pendulum-v1", verbose=1)
model.learn(20_000, callback=event_callback)
```
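
To see why `n_steps` is only a lower bound with several environments, the following standalone simulation may help. It is a simplified sketch, not the SB3 source: it assumes the event fires as soon as at least `n_steps` timesteps have elapsed since the previous trigger, and that the timestep counter grows by `n_envs` per call to `env.step()`. The function name `trigger_timesteps` is hypothetical.

```python
def trigger_timesteps(n_steps: int, n_envs: int, total_timesteps: int) -> list[int]:
    """Return the timesteps at which the event would fire (simplified simulation)."""
    triggers = []
    last_trigger = 0
    num_timesteps = 0
    while num_timesteps < total_timesteps:
        # One call to env.step() on a vectorized env advances n_envs timesteps
        num_timesteps += n_envs
        if num_timesteps - last_trigger >= n_steps:
            last_trigger = num_timesteps
            triggers.append(num_timesteps)
    return triggers


one_env = trigger_timesteps(n_steps=500, n_envs=1, total_timesteps=2000)
eight_envs = trigger_timesteps(n_steps=500, n_envs=8, total_timesteps=2000)
print(one_env)
print(eight_envs)
```

With a single environment the triggers land on exact multiples of 500, whereas with eight environments each interval is stretched to the next multiple of 8.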
(logeveryntimesteps)=
### LogEveryNTimesteps
A callback derived from {ref}`EveryNTimesteps` that will dump the logged data every `n_steps` timesteps.
```python
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import LogEveryNTimesteps
event_callback = LogEveryNTimesteps(n_steps=1_000)
model = PPO("MlpPolicy", "Pendulum-v1", verbose=1)
# Disable auto-logging by passing `log_interval=None`
model.learn(10_000, callback=event_callback, log_interval=None)
```
(stoptrainingonmaxepisodes)=
### StopTrainingOnMaxEpisodes
Stop the training upon reaching the maximum number of episodes, regardless of the model's `total_timesteps` value.
It also presumes that, for multiple environments, the desired behavior is that the agent trains on each env for `max_episodes`
and in total for `max_episodes * n_envs` episodes.
:::{note}
For multiple environments, the agent will train for a total of `max_episodes * n_envs` episodes.
However, it can't be guaranteed that this training will occur for an exact number of `max_episodes` per environment.
Thus, there is an assumption that, on average, each environment ran for `max_episodes`.
:::
```python
from stable_baselines3 import A2C
from stable_baselines3.common.callbacks import StopTrainingOnMaxEpisodes
# Stops training when the model reaches the maximum number of episodes
callback_max_episodes = StopTrainingOnMaxEpisodes(max_episodes=5, verbose=1)
model = A2C("MlpPolicy", "Pendulum-v1", verbose=1)
# Almost infinite number of timesteps, but the training will stop
# early as soon as the max number of episodes is reached
model.learn(int(1e10), callback=callback_max_episodes)
```
(stoptrainingonnomodelimprovement)=
### StopTrainingOnNoModelImprovement
Stop the training if there is no new best model (no new best mean reward) after more than a specific number of consecutive evaluations.
The idea is to save time in experiments when you know that the learning curves are somehow well-behaved and, therefore,
after many evaluations without improvement the learning has probably stabilized.
It must be used with the {ref}`EvalCallback` and use the event triggered after every evaluation.
```python
import gymnasium as gym
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnNoModelImprovement
# Separate evaluation env
eval_env = gym.make("Pendulum-v1")
# Stop training if there is no improvement after more than 3 evaluations
stop_train_callback = StopTrainingOnNoModelImprovement(max_no_improvement_evals=3, min_evals=5, verbose=1)
eval_callback = EvalCallback(eval_env, eval_freq=1000, callback_after_eval=stop_train_callback, verbose=1)
model = SAC("MlpPolicy", "Pendulum-v1", learning_rate=1e-3, verbose=1)
# Almost infinite number of timesteps, but the training will stop early
# as soon as the number of consecutive evaluations without model
# improvement is greater than 3
model.learn(int(1e10), callback=eval_callback)
```
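
The stopping rule described above can be illustrated with a small standalone function. This is a simplified sketch written from the description, not the SB3 implementation: the name `first_stop_eval` is hypothetical, and it assumes that the first `min_evals` evaluations are ignored and that training stops once the count of consecutive evaluations without a new best exceeds `max_no_improvement_evals`.

```python
from typing import Optional


def first_stop_eval(
    best_rewards: list[float],
    max_no_improvement_evals: int,
    min_evals: int = 0,
) -> Optional[int]:
    """Return the 1-based evaluation index at which training would stop, or None.

    ``best_rewards[i]`` is the best mean reward known after evaluation ``i + 1``.
    """
    no_improvement = 0
    last_best = float("-inf")
    for n_eval, best in enumerate(best_rewards, start=1):
        if n_eval > min_evals:
            if best > last_best:
                # A new best model resets the counter
                no_improvement = 0
            else:
                no_improvement += 1
                if no_improvement > max_no_improvement_evals:
                    return n_eval
        last_best = best
    return None


# Best mean reward improves for five evaluations, then plateaus
history = [-900.0, -800.0, -700.0, -600.0, -500.0, -500.0, -500.0, -500.0, -500.0, -500.0]
stop_at = first_stop_eval(history, max_no_improvement_evals=3, min_evals=5)
print(stop_at)
```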
```{eval-rst}
.. automodule:: stable_baselines3.common.callbacks
  :members:
```
================================================
FILE: docs/guide/checking_nan.md
================================================
# Dealing with NaNs and infs
During the training of a model on a given environment, it is possible that the RL model becomes completely
corrupted when a NaN or an inf is given or returned from the RL model.
## How and why?
The issue arises when NaNs or infs do not crash, but simply get propagated through the training,
until all the floating point numbers converge to NaN or inf. This is in line with the
[IEEE Standard for Floating-Point Arithmetic (IEEE 754)](https://ieeexplore.ieee.org/document/4610935), as it says:
:::{note}
Five possible exceptions can occur:
- Invalid operation ($\sqrt{-1}$, $\inf \times 0$, $\text{NaN}\ \mathrm{mod}\ 1$, ...) returns NaN
- Division by zero:
  - if the operand is not zero ($1/0$, $-2/0$, ...) returns $\pm\inf$
  - if the operand is zero ($0/0$) returns signaling NaN
- Overflow (exponent too high to represent) returns $\pm\inf$
- Underflow (exponent too low to represent) returns $0$
- Inexact (not representable exactly in base 2, e.g. $1/5$) returns the rounded value (e.g. {code}`assert (1/5) * 3 == 0.6000000000000001`)
:::
Of these, only `Division by zero` will signal an exception; the rest will propagate invalid values quietly.
In Python, dividing by zero will indeed raise the exception `ZeroDivisionError: float division by zero`,
but the other cases are ignored.
By default, NumPy will warn with `RuntimeWarning: invalid value encountered`
but will not halt the code.
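
The behavior described above can be checked with plain Python floats. This is a minimal illustration using only the standard library; it covers float division, multiplication and subtraction only, and other operations may behave differently.

```python
import math

# Division by zero on Python floats raises instead of returning inf
try:
    1.0 / 0.0
    raised = False
except ZeroDivisionError:
    raised = True
print("ZeroDivisionError raised:", raised)

# Overflow in float multiplication silently yields inf ...
big = 1e308 * 10
# ... and an invalid operation (inf - inf) silently yields NaN
invalid = big - big
# Once a NaN exists, arithmetic that depends on it also yields NaN
propagated = invalid * 0.0 + 1.0
print(math.isinf(big), math.isnan(invalid), math.isnan(propagated))
```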
## Anomaly detection with PyTorch
To enable NaN detection in PyTorch you can do
```python
import torch as th
th.autograd.set_detect_anomaly(True)
```
## Numpy parameters
NumPy has a convenient way of dealing with invalid values: [numpy.seterr](https://docs.scipy.org/doc/numpy/reference/generated/numpy.seterr.html),
which defines how the Python process should handle floating point errors.
```python
import numpy as np
np.seterr(all="raise") # define before your code.
print("numpy test:")
a = np.float64(1.0)
b = np.float64(0.0)
val = a / b # this will now raise an exception instead of a warning.
print(val)
```
This setting also catches overflow issues on floating point numbers:
```python
import numpy as np
np.seterr(all="raise") # define before your code.
print("numpy overflow test:")
a = np.float64(10)
b = np.float64(1000)
val = a ** b # this will now raise an exception
print(val)
```
However, it will not prevent the propagation issues:
```python
import numpy as np
np.seterr(all="raise") # define before your code.
print("numpy propagation test:")
a = np.float64("NaN")
b = np.float64(1.0)
val = a + b # this will neither warn nor raise anything
print(val)
```
## VecCheckNan Wrapper
To find when and where an invalid value originated, stable-baselines3 comes with a `VecCheckNan` wrapper.
It monitors the actions, observations, and rewards, and reports which action or observation caused the invalid value and where it came from.
```python
import gymnasium as gym
from gymnasium import spaces
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan
class NanAndInfEnv(gym.Env):
    """Custom Environment that raises NaNs and Infs"""

    metadata = {"render_modes": ["human"]}

    def __init__(self):
        super().__init__()
        self.action_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float64)
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float64)

    def step(self, _action):
        randf = np.random.rand()
        if randf > 0.99:
            obs = float("NaN")
        elif randf > 0.98:
            obs = float("inf")
        else:
            obs = randf
        # Gymnasium API: observation, reward, terminated, truncated, info
        return [obs], 0.0, False, False, {}

    def reset(self, seed=None, options=None):
        # Gymnasium API: observation, info
        return [0.0], {}

    def render(self):
        pass
# Create environment
env = DummyVecEnv([lambda: NanAndInfEnv()])
env = VecCheckNan(env, raise_exception=True)
# Instantiate the agent
model = PPO("MlpPolicy", env)
# Train the agent
model.learn(total_timesteps=int(2e5)) # this will crash explaining that the invalid value originated from the environment.
```
## RL Model hyperparameters
Depending on your hyperparameters, NaN can occur much more often.
A great example of this: <https://github.com/hill-a/stable-baselines/issues/340>
Be aware that the default hyperparameters seem to work in most cases,
but your environment might not play nicely with them.
If this is the case, try to read up on the effect each hyperparameter has on the model,
so that you can tune them to get a stable model. Alternatively, you can try automatic hyperparameter tuning (included in the RL Zoo).
## Missing values from datasets
If your environment is generated from an external dataset, make sure the dataset does not contain NaNs,
as some datasets fill missing values with NaNs as a surrogate value.
Here is some reading material about finding NaNs: <https://pandas.pydata.org/pandas-docs/stable/user_guide/missing_data.html>
And filling the missing values with something else (imputation): <https://towardsdatascience.com/how-to-handle-missing-data-8646b18db0d4>
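
As a minimal sketch of such a check, the snippet below scans a small in-memory dataset for NaNs with the standard library only. The helper `rows_with_nan` and the sample data are hypothetical; with pandas, the equivalent check would typically rely on `DataFrame.isna()`.

```python
import math


def rows_with_nan(rows: list[list[float]]) -> list[int]:
    """Return the indices of rows that contain at least one NaN."""
    return [i for i, row in enumerate(rows) if any(math.isnan(v) for v in row)]


# Hypothetical dataset where one missing value was encoded as NaN
data = [
    [0.1, 0.2],
    [float("nan"), 0.4],
    [0.5, 0.6],
]
bad = rows_with_nan(data)
# Simplest strategy: drop the affected rows (imputation is the alternative)
clean = [row for i, row in enumerate(data) if i not in bad]
print(bad, clean)
```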
================================================
FILE: docs/guide/custom_env.md
================================================
(custom-env)=
# Using Custom Environments
To use the RL baselines with custom environments, they just need to follow the *gymnasium* [interface](https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/#sphx-glr-tutorials-gymnasium-basics-environment-creation-py).
That is to say, your environment must implement the following methods (and inherit from the `gymnasium.Env` class):
:::{note}
If you are using images as input, the observation must be of type `np.uint8` and be within a space `Box` bounded by [0, 255] (`Box(low=0, high=255, shape=(<your image shape>))`).
By default, the observation is normalized by SB3 pre-processing (dividing by 255 to have values in [0, 1], i.e. `Box(low=0, high=1)`) when using CNN policies.
Images can be either channel-first or channel-last.
If you want to use `CnnPolicy` or `MultiInputPolicy` with image-like observations (3D tensors) that are already normalized, you must pass `normalize_images=False`
to the policy (using the `policy_kwargs` parameter, `policy_kwargs=dict(normalize_images=False)`)
and make sure your image is in the **channel-first** format.
:::
:::{note}
Although SB3 supports both channel-last and channel-first images as input, we recommend using the channel-first convention when possible.
Under the hood, when a channel-last image is passed, SB3 uses a `VecTransposeImage` wrapper to re-order the channels.
:::
:::{note}
SB3 doesn't support `Discrete` and `MultiDiscrete` spaces with `start!=0`. However, you can update your environment or use a wrapper to make your env compatible with SB3:
```python
import gymnasium as gym
class ShiftWrapper(gym.Wrapper):
    """Allow to use Discrete() action spaces with start!=0"""

    def __init__(self, env: gym.Env) -> None:
        super().__init__(env)
        assert isinstance(env.action_space, gym.spaces.Discrete)
        self.action_space = gym.spaces.Discrete(env.action_space.n, start=0)

    def step(self, action: int):
        return self.env.step(action + self.env.action_space.start)
```
:::
:::{note}
SB3 doesn't support `MultiDiscrete` spaces with multi-dimensional arrays. However, you can update your environment or use a wrapper to make your env compatible with SB3:
```python
import numpy as np
import gymnasium as gym
class ReshapeWrapper(gym.Wrapper):
    """Allow to use MultiDiscrete() action spaces with len(nvec.shape) > 1:"""

    def __init__(self, env: gym.Env) -> None:
        super().__init__(env)
        assert isinstance(env.action_space, gym.spaces.MultiDiscrete)
        self.original_shape = env.action_space.nvec.shape
        self.action_space = gym.spaces.MultiDiscrete(env.action_space.nvec.flatten())

    def step(self, action: np.ndarray):
        return self.env.step(action.reshape(self.original_shape))
```
:::
```python
import gymnasium as gym
import numpy as np
from gymnasium import spaces


class CustomEnv(gym.Env):
    """Custom Environment that follows gym interface."""

    metadata = {"render_modes": ["human"], "render_fps": 30}

    def __init__(self, arg1, arg2, ...):
        super().__init__()
        # Define action and observation space
        # They must be gym.spaces objects
        # Example when using discrete actions:
        self.action_space = spaces.Discrete(N_DISCRETE_ACTIONS)
        # Example for using image as input (channel-first; channel-last also works):
        self.observation_space = spaces.Box(low=0, high=255,
                                            shape=(N_CHANNELS, HEIGHT, WIDTH), dtype=np.uint8)

    def step(self, action):
        ...
        return observation, reward, terminated, truncated, info

    def reset(self, seed=None, options=None):
        ...
        return observation, info

    def render(self):
        ...

    def close(self):
        ...
```
Then you can define and train an RL agent with:
```python
from stable_baselines3 import A2C

# Instantiate the env
env = CustomEnv(arg1, ...)
# Define and train the agent
model = A2C("CnnPolicy", env).learn(total_timesteps=1000)
```
To check that your environment follows the Gym interface that SB3 supports, please use:
```python
from stable_baselines3.common.env_checker import check_env
env = CustomEnv(arg1, ...)
# It will check your custom environment and output additional warnings if needed
check_env(env)
```
Gymnasium also has its own [env checker](https://gymnasium.farama.org/api/utils/#gymnasium.utils.env_checker.check_env), but it checks a superset of what SB3 supports (SB3 does not support all Gym features).
We have created a [colab notebook](https://colab.research.google.com/github/araffin/rl-tutorial-jnrr19/blob/sb3/5_custom_gym_env.ipynb) for a concrete example of creating a custom environment along with an example of using it with Stable-Baselines3 interface.
Alternatively, you may look at Gymnasium [built-in environments](https://gymnasium.farama.org).
Optionally, you can also register the environment with gym, which will allow you to create the RL agent in one line (and use `gym.make()` to instantiate the env):
```python
from gymnasium.envs.registration import register

# Example for the CartPole environment
register(
    # unique identifier for the env `name-version`
    id="CartPole-v1",
    # path to the class for creating the env
    # Note: entry_point also accepts a class as input (and not only a string)
    entry_point="gymnasium.envs.classic_control:CartPoleEnv",
    # Max number of steps per episode, using a `TimeLimit` wrapper
    max_episode_steps=500,
)
```
In the project, for testing purposes, we use a custom environment named `IdentityEnv`
defined [in this file](https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/envs/identity_env.py).
An example of how to use it can be found [here](https://github.com/DLR-RM/stable-baselines3/blob/master/tests/test_identity.py).
================================================
FILE: docs/guide/custom_policy.md
================================================
(custom-policy)=
# Policy Networks
Stable Baselines3 provides policy networks for images (CnnPolicies),
other types of input features (MlpPolicies) and multiple different inputs (MultiInputPolicies).
:::{warning}
For A2C and PPO, continuous actions are clipped during training and testing
(to avoid out-of-bounds errors). SAC, DDPG and TD3 squash the action, using a `tanh()` transformation,
which handles bounds more correctly.
:::
## SB3 Policy
SB3 networks are separated into two main parts (see figure below):
- A features extractor (usually shared between actor and critic when applicable, to save computation)
whose role is to extract features (i.e. convert to a feature vector) from high-dimensional observations, for instance, a CNN that extracts features from images.
This is the `features_extractor_class` parameter. You can change the default parameters of that features extractor
by passing a `features_extractor_kwargs` parameter.
- A (fully-connected) network that maps the features to actions/value. Its architecture is controlled by the `net_arch` parameter.
:::{note}
All observations are first pre-processed (e.g. images are normalized, discrete obs are converted to one-hot vectors, ...) before being fed to the features extractor.
In the case of vector observations, the features extractor is just a `Flatten` layer.
:::
```{image} ../_static/img/net_arch.png
```
SB3 policies are usually composed of several networks (actor/critic networks + target networks when applicable) together
with the associated optimizers.
Each of these networks has a features extractor followed by a fully-connected network.
:::{note}
When we refer to "policy" in Stable-Baselines3, this is usually an abuse of language compared to RL terminology.
In SB3, "policy" refers to the class that handles all the networks useful for training,
so not only the network used to predict actions (the "learned controller").
:::
```{image} ../_static/img/sb3_policy.png
```
## Default Network Architecture
The default network architecture used by SB3 depends on the algorithm and the observation space.
You can visualize the architecture by printing `model.policy` (see [issue #329](https://github.com/DLR-RM/stable-baselines3/issues/329)).
For 1D observation spaces, a two-layer fully connected network is used with:
- 64 units (per layer) for PPO/A2C/DQN
- 256 units for SAC
- [400, 300] units for TD3/DDPG (values are taken from the original TD3 paper)
For image observation spaces, the "Nature CNN" (see code for more details) is used for feature extraction, and SAC/TD3 also keep the same fully connected network after it.
The other algorithms only have a linear layer after the CNN.
The CNN is shared between actor and critic for A2C/PPO (on-policy algorithms) to reduce computation.
Off-policy algorithms (TD3, DDPG, SAC, ...) have separate feature extractors: one for the actor and one for the critic, since the best performance is obtained with this configuration.
For mixed observations (dictionary observations), the two architectures from above are used, i.e., CNN for images and then two layers fully-connected network
(with a smaller output size for the CNN).
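To make the defaults for 1D observations concrete, here is a minimal sketch in plain Python. The table only restates the sizes quoted above; `count_mlp_parameters` is a hypothetical helper (not an SB3 function) that counts weights and biases of a single fully connected head, whereas a real SB3 policy contains several such networks.

```python
# Hidden layer sizes quoted in the text above for 1D observation spaces
DEFAULT_NET_ARCH = {
    "PPO": [64, 64],
    "A2C": [64, 64],
    "DQN": [64, 64],
    "SAC": [256, 256],
    "TD3": [400, 300],
    "DDPG": [400, 300],
}


def count_mlp_parameters(input_dim: int, hidden_sizes: list, output_dim: int) -> int:
    """Count weights and biases of a plain MLP: input -> hidden layers -> output."""
    sizes = [input_dim, *hidden_sizes, output_dim]
    return sum(n_in * n_out + n_out for n_in, n_out in zip(sizes[:-1], sizes[1:]))


# Assumed example: 4-dimensional observation, 2 outputs (CartPole-like sizes)
n_params = count_mlp_parameters(4, DEFAULT_NET_ARCH["PPO"], 2)
print(n_params)
```

Printing `model.policy`, as suggested above, remains the authoritative way to inspect what a given model actually uses.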
## Custom Network Architecture
One way of customising the policy network architecture is to pass arguments when creating the model,
using `policy_kwargs` parameter:
:::{note}
An extra linear layer will be added on top of the layers specified in `net_arch`, in order to have the right output dimensions and activation functions (e.g. Softmax for discrete actions).
In the following example, as CartPole's action space has a dimension of 2, the final dimensions of the `net_arch`'s layers will be:
```none
        obs
        <4>
   /          \
 <32>        <32>
  |            |
 <32>        <32>
  |            |
 <2>          <1>
action       value
```
:::
```python
import gymnasium as gym
import torch as th
from stable_baselines3 import PPO
# Custom actor (pi) and value function (vf) networks
# of two layers of size 32 each with ReLU activation function
# Note: an extra linear layer will be added on top of the pi and the vf nets, respectively
policy_kwargs = dict(activation_fn=th.nn.ReLU,
                     net_arch=dict(pi=[32, 32], vf=[32, 32]))
# Create the agent
model = PPO("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1)
# Retrieve the environment
env = model.get_env()
# Train the agent
model.learn(total_timesteps=20_000)
# Save the agent
model.save("ppo_cartpole")
del model
# the policy_kwargs are automatically loaded
model = PPO.load("ppo_cartpole", env=env)
```
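The layer dimensions for `net_arch=dict(pi=[32, 32], vf=[32, 32])` on CartPole can be reproduced with a small sketch in plain Python. `linear_layer_shapes` is a hypothetical helper written for this illustration, and the sizes (4 observation features, 2 actions, 1 value output) are the ones from the note above.

```python
def linear_layer_shapes(input_dim: int, hidden_sizes: list, output_dim: int) -> list:
    """Return the (in_features, out_features) pair of every linear layer, output layer included."""
    sizes = [input_dim, *hidden_sizes, output_dim]
    return list(zip(sizes[:-1], sizes[1:]))


net_arch = dict(pi=[32, 32], vf=[32, 32])
actor_shapes = linear_layer_shapes(4, net_arch["pi"], 2)   # extra layer maps to the 2 actions
critic_shapes = linear_layer_shapes(4, net_arch["vf"], 1)  # extra layer maps to a scalar value
print(actor_shapes, critic_shapes)
```

The last pair in each list is the extra linear layer that is added on top of the layers listed in `net_arch`.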
## Custom Feature Extractor
If you want to have a custom features extractor (e.g. custom CNN when using images), you can define a class
that derives from `BaseFeaturesExtractor` and then pass it to the model when training.
:::{note}
For on-policy algorithms, the features extractor is shared by default between the actor and the critic to save computation (when applicable).
However, this can be changed by setting `share_features_extractor=False` in the
`policy_kwargs` (both for on-policy and off-policy algorithms).
:::
```python
import torch as th
import torch.nn as nn
from gymnasium import spaces

from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


class CustomCNN(BaseFeaturesExtractor):
    """
    :param observation_space: (gym.Space)
    :param features_dim: (int) Number of features extracted.
        This corresponds to the number of units for the last layer.
    """

    def __init__(self, observation_space: spaces.Box, features_dim: int = 256):
        super().__init__(observation_space, features_dim)
        # We assume CxHxW images (channels first)
        # Re-ordering will be done by pre-processing or wrapper
        n_input_channels = observation_space.shape[0]
        self.cnn = nn.Sequential(
            nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
            nn.ReLU(),
            nn.Flatten(),
        )

        # Compute shape by doing one forward pass
        with th.no_grad():
            n_flatten = self.cnn(
                th.as_tensor(observation_space.sample()[None]).float()
            ).shape[1]

        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())

    def forward(self, observations: th.Tensor) -> th.Tensor:
        return self.linear(self.cnn(observations))


policy_kwargs = dict(
    features_extractor_class=CustomCNN,
    features_extractor_kwargs=dict(features_dim=128),
)
model = PPO("CnnPolicy", "BreakoutNoFrameskip-v4", policy_kwargs=policy_kwargs, verbose=1)
model.learn(1000)
```
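The example above obtains `n_flatten` from one forward pass. As a cross-check, the same number can be derived by hand with the standard output-size formula for a convolution without dilation. The sketch below is plain Python and assumes an 84x84 input, a common Atari preprocessing size that is not stated in the example itself.

```python
def conv2d_out_size(size: int, kernel_size: int, stride: int, padding: int = 0) -> int:
    """Output length along one spatial axis of a convolution (dilation of 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1


height = width = 84  # assumption: 84x84 frames
# (kernel_size, stride) of the two Conv2d layers in CustomCNN
for kernel_size, stride in [(8, 4), (4, 2)]:
    height = conv2d_out_size(height, kernel_size, stride)
    width = conv2d_out_size(width, kernel_size, stride)

out_channels = 64  # channels produced by the last Conv2d layer
n_flatten = out_channels * height * width
print(height, width, n_flatten)
```

Doing a forward pass, as in the example, avoids having to update such arithmetic whenever the layers change.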
## Multiple Inputs and Dictionary Observations
Stable Baselines3 supports handling of multiple inputs by using `Dict` Gym space. This can be done using
`MultiInputPolicy`, which by default uses the `CombinedExtractor` features extractor to turn multiple
inputs into a single vector, handled by the `net_arch` network.
By default, `CombinedExtractor` processes multiple inputs as follows:
1. If input is an image (automatically detected, see `common.preprocessing.is_image_space`), process image with Nature Atari CNN network and
output a latent vector of size `256`.
2. If input is not an image, flatten it (no layers).
3. Concatenate all previous vectors into one long vector and pass it to policy.
Much like above, you can define custom features extractors. The following example assumes the environment has two keys in the
observation space dictionary: "image" is a (1,H,W) image (channel first), and "vector" is a (D,) dimensional vector. We process "image" with a simple
downsampling and "vector" with a single linear layer.
```python
import gymnasium as gym
import torch as th
from torch import nn

from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


class CustomCombinedExtractor(BaseFeaturesExtractor):
    def __init__(self, observation_space: gym.spaces.Dict):
        # We do not know features-dim here before going over all the items,
        # so put something dummy for now. PyTorch requires calling
        # nn.Module.__init__ before adding modules
        super().__init__(observation_space, features_dim=1)

        extractors = {}

        total_concat_size = 0
        # We need to know size of the output of this extractor,
        # so go over all the spaces and compute output feature sizes
        for key, subspace in observation_space.spaces.items():
            if key == "image":
                # We will just downsample one channel of the image by 4x4 and flatten.
                # Assume the image is single-channel (subspace.shape[0] == 1)
                extractors[key] = nn.Sequential(nn.MaxPool2d(4), nn.Flatten())
                # Parentheses are needed: `a // 4 * b // 4` is evaluated left to right
                total_concat_size += (subspace.shape[1] // 4) * (subspace.shape[2] // 4)
            elif key == "vector":
                # Run through a simple MLP
                extractors[key] = nn.Linear(subspace.shape[0], 16)
                total_concat_size += 16

        self.extractors = nn.ModuleDict(extractors)

        # Update the features dim manually
        self._features_dim = total_concat_size

    def forward(self, observations) -> th.Tensor:
        encoded_tensor_list = []

        # self.extractors contain nn.Modules that do all the processing.
        for key, extractor in self.extractors.items():
            encoded_tensor_list.append(extractor(observations[key]))
        # Return a (B, self._features_dim) PyTorch tensor, where B is batch dimension.
        return th.cat(encoded_tensor_list, dim=1)
```
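The size bookkeeping in this extractor is easy to get wrong, so here is a minimal sketch of the arithmetic in plain Python. `combined_features_dim` is a hypothetical helper and the 10x10 image is an assumed size chosen so that the height and width are not multiples of 4.

```python
def combined_features_dim(image_shape: tuple, pool: int = 4, vector_out: int = 16) -> int:
    """Feature size after max-pooling and flattening a (1, H, W) image, plus the linear output for the vector."""
    _, height, width = image_shape
    # Each spatial axis is floor-divided by the pooling size independently
    return (height // pool) * (width // pool) + vector_out


features_dim = combined_features_dim((1, 10, 10))

# Without parentheses, Python evaluates left to right: ((10 // 4) * 10) // 4
unparenthesized = 10 // 4 * 10 // 4
print(features_dim, unparenthesized)
```

The two expressions agree when the height is a multiple of the pooling size, which is why the difference can go unnoticed on typical image sizes.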
## On-Policy Algorithms
### Custom Networks
If you need a network architecture that is different for the actor and the critic when using `PPO`, `A2C` or `TRPO`,
you can pass a dictionary of the following structure: `dict(pi=[<actor network architecture>], vf=[<critic network architecture>])`.
For example, if you want a different architecture for the actor (aka `pi`) and the critic (value-function aka `vf`) networks,
then you can specify `net_arch=dict(pi=[32, 32], vf=[64, 64])`.
Otherwise, to have actor and critic that share the same network architecture,
you only need to specify `net_arch=[128, 128]` (here, two hidden layers of 128 units each, this is equivalent to `net_arch=dict(pi=[128, 128], vf=[128, 128])`).
If shared layers are needed, you need to implement a custom policy network (see [advanced example below](#advanced-example)).
#### Examples
Same architecture for actor and critic with two layers of size 128: `net_arch=[128, 128]`
```none
          obs
   /            \
 <128>         <128>
   |             |
 <128>         <128>
   |             |
 action        value
```
Different architectures for actor and critic: `net_arch=dict(pi=[32, 32], vf=[64, 64])`
```none
          obs
   /            \
 <32>          <64>
   |             |
 <32>          <64>
   |             |
 action        value
```
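The equivalence between the list form and the dictionary form of `net_arch` can be written down as a small sketch. `expand_net_arch` is a hypothetical helper that mirrors the rule stated above; it is not the function SB3 uses internally.

```python
from typing import Dict, List, Union


def expand_net_arch(net_arch: Union[List[int], Dict[str, List[int]]]) -> Dict[str, List[int]]:
    """Return separate actor (pi) and critic (vf) layer sizes for either form of net_arch."""
    if isinstance(net_arch, dict):
        # Dictionary form: each network has its own layer sizes
        return {"pi": list(net_arch.get("pi", [])), "vf": list(net_arch.get("vf", []))}
    # List form: actor and critic use the same sizes, but do not share layers
    return {"pi": list(net_arch), "vf": list(net_arch)}


same = expand_net_arch([128, 128])
different = expand_net_arch(dict(pi=[32, 32], vf=[64, 64]))
print(same, different)
```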
#### Advanced Example
If your task requires even more granular control over the policy/value architecture, you can redefine the policy directly:
```python
from typing import Callable, Dict, List, Optional, Tuple, Type, Union

from gymnasium import spaces
import torch as th
from torch import nn

from stable_baselines3 import PPO
from stable_baselines3.common.policies import ActorCriticPolicy


class CustomNetwork(nn.Module):
    """
    Custom network for policy and value function.
    It receives as input the features extracted by the features extractor.

    :param feature_dim: dimension of the features extracted with the features_extractor (e.g. features from a CNN)
    :param last_layer_dim_pi: (int) number of units for the last layer of the policy network
    :param last_layer_dim_vf: (int) number of units for the last layer of the value network
    """

    def __init__(
        self,
        feature_dim: int,
        last_layer_dim_pi: int = 64,
        last_layer_dim_vf: int = 64,
    ):
        super().__init__()

        # IMPORTANT:
        # Save output dimensions, used to create the distributions
        self.latent_dim_pi = last_layer_dim_pi
        self.latent_dim_vf = last_layer_dim_vf

        # Policy network
        self.policy_net = nn.Sequential(
            nn.Linear(feature_dim, last_layer_dim_pi), nn.ReLU()
        )
        # Value network
        self.value_net = nn.Sequential(
            nn.Linear(feature_dim, last_layer_dim_vf), nn.ReLU()
        )

    def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        """
        :return: (th.Tensor, th.Tensor) latent_policy, latent_value of the specified network.
            If all layers are shared, then ``latent_policy == latent_value``
        """
        return self.forward_actor(features), self.forward_critic(features)

    def forward_actor(self, features: th.Tensor) -> th.Tensor:
        return self.policy_net(features)

    def forward_critic(self, features: th.Tensor) -> th.Tensor:
        return self.value_net(features)


class CustomActorCriticPolicy(ActorCriticPolicy):
    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        lr_schedule: Callable[[float], float],
        *args,
        **kwargs,
    ):
        # Disable orthogonal initialization
        kwargs["ortho_init"] = False
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            # Pass remaining arguments to base class
            *args,
            **kwargs,
        )

    def _build_mlp_extractor(self) -> None:
        self.mlp_extractor = CustomNetwork(self.features_dim)


model = PPO(CustomActorCriticPolicy, "CartPole-v1", verbose=1)
model.learn(5000)
```
## Off-Policy Algorithms
If you need a network architecture that is different for the actor and the critic when using `SAC`, `DDPG`, `TQC` or `TD3`,
you can pass a dictionary of the following structure: `dict(pi=[<actor network architecture>], qf=[<critic network architecture>])`.
For example, if you want a different architecture for the actor (aka `pi`) and the critic (Q-function aka `qf`) networks,
then you can specify `net_arch=dict(pi=[64, 64], qf=[400, 300])`.
Otherwise, to have actor and critic that share the same network architecture,
you only need to specify `net_arch=[256, 256]` (here, two hidden layers of 256 units each).
:::{note}
For advanced customization of off-policy algorithms policies, please take a look at the code.
A good understanding of the algorithm used is required, see discussion in [issue #425](https://github.com/DLR-RM/stable-baselines3/issues/425)
:::
```python
from stable_baselines3 import SAC
# Custom actor architecture with two layers of 64 units each
# Custom critic architecture with two layers of 400 and 300 units
policy_kwargs = dict(net_arch=dict(pi=[64, 64], qf=[400, 300]))
# Create the agent
model = SAC("MlpPolicy", "Pendulum-v1", policy_kwargs=policy_kwargs, verbose=1)
model.learn(5000)
```
================================================
FILE: docs/guide/developer.md
================================================
(developer)=
# Developer Guide
This guide is meant for those who want to understand the internals and the design choices of Stable-Baselines3.
First, you should read the two issues where the design choices were discussed:
- <https://github.com/hill-a/stable-baselines/issues/576>
- <https://github.com/hill-a/stable-baselines/issues/733>
The library is not meant to be modular, although inheritance is used to reduce code duplication.
## Algorithms Structure
Each algorithm (on-policy and off-policy ones) follows a common structure.
The policy contains the code for acting in the environment, and the algorithm updates this policy.
There is one folder per algorithm, and in that folder there is the algorithm and the policy definition (`policies.py`).
Each algorithm has two main methods:
- `.collect_rollouts()` which defines how new samples are collected, usually inherited from the base class. Those samples are then stored in a `RolloutBuffer` (discarded after the gradient update) or `ReplayBuffer`
- `.train()` which updates the parameters using samples from the buffer
```{image} ../_static/img/sb3_loop.png
```
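The two-method structure can be illustrated with a toy loop in plain Python. Everything in this sketch (`ToyOnPolicyAlgorithm`, its attributes, the random numbers standing in for transitions) is a simplified assumption made for illustration and does not reproduce the actual SB3 classes.

```python
import random


class ToyOnPolicyAlgorithm:
    """Toy stand-in showing how collect_rollouts() and train() alternate."""

    def __init__(self, n_steps: int = 5) -> None:
        self.n_steps = n_steps
        self.rollout_buffer = []
        self.num_timesteps = 0
        self.n_updates = 0

    def collect_rollouts(self) -> None:
        # On-policy: samples from the previous iteration are discarded
        self.rollout_buffer = []
        for _ in range(self.n_steps):
            self.rollout_buffer.append(random.random())  # stand-in for one transition
            self.num_timesteps += 1

    def train(self) -> None:
        # A real implementation would compute a loss and update the policy here
        assert len(self.rollout_buffer) == self.n_steps
        self.n_updates += 1

    def learn(self, total_timesteps: int) -> "ToyOnPolicyAlgorithm":
        while self.num_timesteps < total_timesteps:
            self.collect_rollouts()
            self.train()
        return self


algo = ToyOnPolicyAlgorithm(n_steps=5).learn(total_timesteps=12)
print(algo.num_timesteps, algo.n_updates)
```

An off-policy variant would keep appending to a replay buffer instead of clearing it before each collection phase.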
## Where to start?
The first thing you need to read and understand are the base classes in the `common/` folder:
- `BaseAlgorithm` in `base_class.py` which defines what an RL class should look like.
It also contains all the "glue code" for saving/loading and the common operations (wrapping environments)
- `BasePolicy` in `policies.py` which defines what a policy class should look like.
It also contains all the magic for the `.predict()` method, to handle as many spaces/cases as possible
- `OffPolicyAlgorithm` in `off_policy_algorithm.py` that contains the implementation of `collect_rollouts()` for the off-policy algorithms,
and similarly `OnPolicyAlgorithm` in `on_policy_algorithm.py`.
All the environments handled internally are assumed to be `VecEnv` (`gym.Env` are automatically wrapped).
## Pre-Processing
To handle different observation spaces, some pre-processing needs to be done (e.g. one-hot encoding for discrete observation).
Most of the code for pre-processing is in `common/preprocessing.py` and `common/policies.py`.
For images, the environment is automatically wrapped with `VecTransposeImage` if observations are detected to be images with
channel-last convention to transform it to PyTorch's channel-first convention.
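Both pre-processing steps mentioned here are simple to state in code. The sketch below uses plain Python lists and tuples; `one_hot` and `to_channel_first` are hypothetical helpers written for illustration, not the functions found in `common/preprocessing.py`.

```python
def one_hot(index: int, n: int) -> list:
    """Encode a discrete observation in [0, n) as a one-hot vector of length n."""
    if not 0 <= index < n:
        raise ValueError(f"index {index} is outside [0, {n})")
    return [1.0 if i == index else 0.0 for i in range(n)]


def to_channel_first(shape: tuple) -> tuple:
    """Convert a channel-last image shape (H, W, C) to channel-first (C, H, W)."""
    height, width, channels = shape
    return (channels, height, width)


encoded = one_hot(2, 4)
transposed_shape = to_channel_first((84, 84, 3))  # assumed 84x84 RGB image
print(encoded, transposed_shape)
```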
## Policy Structure
When we refer to "policy" in Stable-Baselines3, this is usually an abuse of language compared to RL terminology.
In SB3, "policy" refers to the class that handles all the networks useful for training,
so not only the network used to predict actions (the "learned controller").
For instance, the `TD3` policy contains the actor, the critic and the target networks.
To avoid the hassle of importing specific policy classes for specific algorithm (e.g. both A2C and PPO use `ActorCriticPolicy`),
SB3 uses names like "MlpPolicy" and "CnnPolicy" to refer to policies using small feed-forward networks or convolutional networks,
respectively. Importing `[algorithm]/policies.py` registers an appropriate policy for that algorithm under those names.
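The idea of resolving a policy name per algorithm can be sketched as a nested lookup table. The classes below are empty stand-ins and `POLICY_ALIASES` / `resolve_policy` are hypothetical names for this illustration; the sketch shows the concept only, not how SB3 stores these aliases.

```python
class ActorCriticPolicy:
    """Empty stand-in for a policy shared by A2C and PPO."""


class ActorCriticCnnPolicy:
    """Empty stand-in for the convolutional variant."""


class TD3Policy:
    """Empty stand-in for the TD3 policy."""


# The same string resolves to a different class depending on the algorithm
POLICY_ALIASES = {
    "A2C": {"MlpPolicy": ActorCriticPolicy, "CnnPolicy": ActorCriticCnnPolicy},
    "PPO": {"MlpPolicy": ActorCriticPolicy, "CnnPolicy": ActorCriticCnnPolicy},
    "TD3": {"MlpPolicy": TD3Policy},
}


def resolve_policy(algorithm: str, name: str) -> type:
    """Look up the policy class registered under `name` for `algorithm`."""
    try:
        return POLICY_ALIASES[algorithm][name]
    except KeyError as err:
        raise ValueError(f"Policy {name!r} unknown for {algorithm}") from err


ppo_policy = resolve_policy("PPO", "MlpPolicy")
td3_policy = resolve_policy("TD3", "MlpPolicy")
print(ppo_policy.__name__, td3_policy.__name__)
```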
## Probability distributions
When needed, the policies handle the different probability distributions.
All distributions are located in `common/distributions.py` and follow the same interface.
Each distribution corresponds to a type of action space (e.g. `Categorical` is the one used for discrete actions).
For continuous actions, we can use multiple distributions ("DiagGaussian", "SquashedGaussian" or "StateDependentDistribution").
## State-Dependent Exploration
State-Dependent Exploration (SDE) is a type of exploration that allows using RL directly on real robots;
it was the starting point for the Stable-Baselines3 library.
I (@araffin) published a paper about a generalized version of SDE (the one implemented in SB3): <https://arxiv.org/abs/2005.05719>
## Misc
The rest of the `common/` is composed of helpers (e.g. evaluation helpers) or basic components (like the callbacks).
The `type_aliases.py` file contains common type hint aliases like `GymStepReturn`.
Et voilà?
After reading this guide and the mentioned files, you should now be able to understand the design logic behind the library ;)
================================================
FILE: docs/guide/examples.md
================================================
---
myst:
  substitutions:
    colab: |-
      ```{image} ../_static/img/colab.svg
      ```
---
(examples)=
# Examples
:::{note}
These examples are only to demonstrate the use of the library and its functions, and the trained agents may not solve the environments. Optimized hyperparameters can be found in the RL Zoo [repository](https://github.com/DLR-RM/rl-baselines3-zoo).
:::
## Try it online with Colab Notebooks!
All the following examples can be executed online using Google colab {{ colab }}
notebooks:
- [Full Tutorial](https://github.com/araffin/rl-tutorial-jnrr19/tree/sb3)
- [All Notebooks](https://github.com/Stable-Baselines-Team/rl-colab-notebooks/tree/sb3)
- [Getting Started](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/stable_baselines_getting_started.ipynb)
- [Training, Saving, Loading](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/saving_loading_dqn.ipynb)
- [Multiprocessing]
- [Monitor Training and Plotting]
- [Atari Games]
- [RL Baselines zoo]
- [PyBullet]
- [Hindsight Experience Replay]
- [Advanced Saving and Loading]
## Basic Usage: Training, Saving, Loading
In the following example, we will train, save and load a DQN model on the Lunar Lander environment.
```{image} ../_static/img/colab-badge.svg
:target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/saving_loading_dqn.ipynb
```
:::{figure} https://cdn-images-1.medium.com/max/960/1*f4VZPKOI0PYNWiwt0la0Rg.gif
Lunar Lander Environment
:::
:::{note}
LunarLander requires the python package `box2d`.
You can install it using `apt install swig` and then `pip install box2d box2d-kengz`
:::
:::{warning}
The `load` method re-creates the model from scratch and should be called on the Algorithm without instantiating it first,
e.g. `model = DQN.load("dqn_lunar", env=env)` instead of `model = DQN(env=env)` followed by `model.load("dqn_lunar")`. The latter **will not work** as `load` is not an in-place operation.
If you want to load parameters without re-creating the model, e.g. to evaluate the same model
with multiple different sets of parameters, consider using `set_parameters` instead.
:::
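Why calling `load` on an existing instance has no effect can be seen with a toy class in plain Python. `ToyModel` is a hypothetical stand-in whose `load` is a class method returning a new object, the same pattern the warning describes; it does not touch files or real network weights.

```python
class ToyModel:
    """Toy stand-in: `load` builds a new object, `set_parameters` modifies in place."""

    def __init__(self, weights=None):
        self.weights = list(weights) if weights is not None else [0.0]

    @classmethod
    def load(cls, saved_weights):
        # Returns a brand-new instance; no existing object is modified
        return cls(saved_weights)

    def set_parameters(self, saved_weights):
        # In-place update of an existing instance
        self.weights = list(saved_weights)


saved = [1.0, 2.0]

model = ToyModel()
model.load(saved)  # return value discarded: `model` is unchanged
unchanged = model.weights

model = ToyModel.load(saved)  # correct usage: keep the returned instance
loaded = model.weights
print(unchanged, loaded)
```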
```python
import gymnasium as gym
from stable_baselines3 import DQN
from stable_baselines3.common.evaluation import evaluate_policy
# Create environment
env = gym.make("LunarLander-v3", render_mode="rgb_array")
# Instantiate the agent
model = DQN("MlpPolicy", env, verbose=1)
# Train the agent and display a progress bar
model.learn(total_timesteps=int(2e5), progress_bar=True)
# Save the agent
model.save("dqn_lunar")
del model # delete trained model to demonstrate loading
# Load the trained agent
# NOTE: if you have loading issue, you can pass `print_system_info=True`
# to compare the system on which the model was trained vs the current one
# model = DQN.load("dqn_lunar", env=env, print_system_info=True)
model = DQN.load("dqn_lunar", env=env)
# Evaluate the agent
# NOTE: If you use wrappers with your environment that modify rewards,
# this will be reflected here. To evaluate with original rewards,
# wrap environment in a "Monitor" wrapper before other wrappers.
mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=10)
# Enjoy trained agent
vec_env = model.get_env()
obs = vec_env.reset()
for i in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, rewards, dones, info = vec_env.step(action)
    vec_env.render("human")
```
## Multiprocessing: Unleashing the Power of Vectorized Environments
```{image} ../_static/img/colab-badge.svg
:target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/multiprocessing_rl.ipynb
```
:::{figure} https://cdn-images-1.medium.com/max/960/1*h4WTQNVIsvMXJTCpXm_TAw.gif
CartPole Environment
:::
```python
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.utils import set_random_seed
def make_env(env_id: str, rank: int, seed: int = 0):
    """
    Utility function for multiprocessed env.

    :param env_id: the environment ID
    :param rank: index of the subprocess
    :param seed: the initial seed for RNG
    """
    def _init():
        env = gym.make(env_id, render_mode="human")
        # Give each subprocess its own seed
        env.reset(seed=seed + rank)
        return env
    set_random_seed(seed)
    return _init


if __name__ == "__main__":
    env_id = "CartPole-v1"
    num_cpu = 4  # Number of processes to use
    # Create the vectorized environment
    vec_env = SubprocVecEnv([make_env(env_id, i) for i in range(num_cpu)])

    # Stable Baselines provides you with make_vec_env() helper
    # which does exactly the previous steps for you.
    # You can choose between `DummyVecEnv` (usually faster) and `SubprocVecEnv`
    # env = make_vec_env(env_id, n_envs=num_cpu, seed=0, vec_env_cls=SubprocVecEnv)

    model = PPO("MlpPolicy", vec_env, verbose=1)
    model.learn(total_timesteps=25_000)

    obs = vec_env.reset()
    for _ in range(1000):
        action, _states = model.predict(obs)
        obs, rewards, dones, info = vec_env.step(action)
        vec_env.render()
```
## Multiprocessing with off-policy algorithms
:::{warning}
When using multiple environments with off-policy algorithms, you should update the `gradient_steps`
parameter too. Set it to `gradient_steps=-1` to perform as many gradient steps as transitions collected.
There is usually a compromise between wall-clock time and sample efficiency,
see this [example in PR #439](https://github.com/DLR-RM/stable-baselines3/pull/439#issuecomment-961796799)
:::
```python
import gymnasium as gym
from stable_baselines3 import SAC
from stable_baselines3.common.env_util import make_vec_env
vec_env = make_vec_env("Pendulum-v1", n_envs=4, seed=0)
# We collect 4 transitions per call to `env.step()`
# and perform 2 gradient steps per call to `env.step()`
# if gradient_steps=-1, then we would do 4 gradient steps per call to `env.step()`
model = SAC("MlpPolicy", vec_env, train_freq=1, gradient_steps=2, verbose=1)
model.learn(total_timesteps=10_000)
```
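The relationship between `n_envs`, `train_freq` and `gradient_steps` described in the comments above can be captured in a few lines of plain Python. `gradient_steps_per_rollout` is a hypothetical helper and assumes `train_freq` is expressed in steps.

```python
def gradient_steps_per_rollout(n_envs: int, train_freq: int, gradient_steps: int) -> int:
    """Number of gradient steps performed after each collection phase."""
    # Each call to env.step() on a vectorized env yields one transition per env
    transitions_collected = n_envs * train_freq
    # -1 means: as many gradient steps as transitions collected
    return transitions_collected if gradient_steps == -1 else gradient_steps


fixed = gradient_steps_per_rollout(n_envs=4, train_freq=1, gradient_steps=2)
matched = gradient_steps_per_rollout(n_envs=4, train_freq=1, gradient_steps=-1)
print(fixed, matched)
```

With `gradient_steps=2` and four environments, the agent performs fewer updates per collected transition than with `gradient_steps=-1`, which is the wall-clock versus sample-efficiency compromise mentioned in the warning.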
## Dict Observations
You can use environments with dictionary observation spaces. This is useful in the case where one can't directly
concatenate observations such as an image from a camera combined with a vector of servo sensor data (e.g., rotation angles).
Stable Baselines3 provides `SimpleMultiObsEnv` as an example of this kind of setting.
The environment is a simple grid world, but the observations for each cell come in the form of dictionaries.
These dictionaries are randomly initialized on the creation of the environment and contain a vector observation and an image observation.
```python
from stable_baselines3 import PPO
from stable_baselines3.common.envs import SimpleMultiObsEnv
# Stable Baselines provides SimpleMultiObsEnv as an example environment with Dict observations
env = SimpleMultiObsEnv(random_start=False)
model = PPO("MultiInputPolicy", env, verbose=1)
model.learn(total_timesteps=100_000)
```
## Callbacks: Monitoring Training
:::{note}
We recommend reading the [Callback section](callbacks.md)
:::
You can define a custom callback function that will be called inside the agent.
This could be useful when you want to monitor training, for instance display live
learning curves in Tensorboard or save the best agent.
If your callback returns False, training is aborted early.
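The early-abort behaviour can be pictured with a toy training loop in plain Python. `run_training` and the lambda callback are hypothetical stand-ins for this sketch; a real SB3 callback is a class derived from `BaseCallback`.

```python
def run_training(total_timesteps: int, on_step) -> int:
    """Toy loop: call `on_step` after every step and stop as soon as it returns False."""
    num_timesteps = 0
    while num_timesteps < total_timesteps:
        num_timesteps += 1  # stand-in for one environment step
        if not on_step(num_timesteps):
            break  # the callback asked to abort training
    return num_timesteps


# Callback that requests a stop once 50 steps have been reached
steps_done = run_training(1000, on_step=lambda t: t < 50)
print(steps_done)
```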
```{image} ../_static/img/colab-badge.svg
:target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/monitor_training.ipynb
```
```python
import os
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt
from stable_baselines3 import TD3
from stable_baselines3.common import results_plotter
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.results_plotter import load_results, ts2xy, plot_results
from stable_baselines3.common.noise import NormalActionNoise
from stable_baselines3.common.callbacks import BaseCallback
class SaveOnBestTrainingRewardCallback(BaseCallback):
    """
    Callback for saving a model (the check is done every ``check_freq`` steps)
    based on the training reward (in practice, we recommend using ``EvalCallback``).

    :param check_freq: Number of steps between two checks.
    :param log_dir: Path to the folder where the model will be saved.
      It must contain the file created by the ``Monitor`` wrapper.
    :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
    """

    def __init__(self, check_freq: int, log_dir: str, verbose: int = 1):
        super().__init__(verbose)
        self.check_freq = check_freq
        self.log_dir = log_dir
        self.save_path = os.path.join(log_dir, "best_model")
        self.best_mean_reward = -np.inf

    def _init_callback(self) -> None:
        # Create folder if needed
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self) -> bool:
        if self.n_calls % self.check_freq == 0:

            # Retrieve training reward
            x, y = ts2xy(load_results(self.log_dir), "timesteps")
            if len(x) > 0:
                # Mean training reward over the last 100 episodes
                mean_reward = np.mean(y[-100:])
                if self.verbose >= 1:
                    print(f"Num timesteps: {self.num_timesteps}")
                    print(f"Best mean reward: {self.best_mean_reward:.2f} - Last mean reward per episode: {mean_reward:.2f}")

                # New best model, you could save the agent here
                if mean_reward > self.best_mean_reward:
                    self.best_mean_reward = mean_reward
                    # Example for saving best model
                    if self.verbose >= 1:
                        print(f"Saving new best model to {self.save_path}")
                    self.model.save(self.save_path)

        return True

# Create log dir
log_dir = "tmp/"
os.makedirs(log_dir, exist_ok=True)
# Create and wrap the environment
env = gym.make("LunarLanderContinuous-v3")
env = Monitor(env, log_dir)
# Add some action noise for exploration
n_actions = env.action_space.shape[-1]
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
# Create the agent, using the action noise defined above for exploration
model = TD3("MlpPolicy", env, action_noise=action_noise, verbose=0)
# Create the callback: check every 1000 steps
callback = SaveOnBestTrainingRewardCallback(check_freq=1000, log_dir=log_dir)
# Train the agent
timesteps = 1e5
model.learn(total_timesteps=int(timesteps), callback=callback)
plot_results([log_dir], timesteps, results_plotter.X_TIMESTEPS, "TD3 LunarLander")
plt.show()
```
## Callbacks: Evaluate Agent Performance
To periodically evaluate an agent's performance on a separate test environment, use `EvalCallback`.
You can control the evaluation frequency with `eval_freq` to monitor your agent's progress during training.
```python
import os
import gymnasium as gym
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.env_util import make_vec_env
env_id = "Pendulum-v1"
n_training_envs = 1
n_eval_envs = 5
# Create log dir where evaluation results will be saved
eval_log_dir = "./eval_logs/"
os.makedirs(eval_log_dir, exist_ok=True)
# Initialize a vectorized training environment with default parameters
train_env = make_vec_env(env_id, n_envs=n_training_envs, seed=0)
# Separate evaluation env, with different parameters passed via env_kwargs
# Eval environments can be vectorized to speed up evaluation.
eval_env = make_vec_env(env_id, n_envs=n_eval_envs, seed=0,
env_kwargs={'g':0.7})
# Create callback that evaluates agent for 5 episodes every 500 training environment steps.
# When using multiple training environments, agent will be evaluated every
# eval_freq calls to train_env.step(), thus it will be evaluated every
# (eval_freq * n_envs) training steps. See EvalCallback doc for more information.
eval_callback = EvalCallback(eval_env, best_model_save_path=eval_log_dir,
log_path=eval_log_dir, eval_freq=max(500 // n_training_envs, 1),
n_eval_episodes=5, deterministic=True,
render=False)
model = SAC("MlpPolicy", train_env)
model.learn(5000, callback=eval_callback)
```
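The relation between `eval_freq` and the number of training environments described in the comments above is simple arithmetic. As a minimal sketch (the helper names `eval_freq_for` and `timesteps_between_evals` are hypothetical and not part of SB3), the following computes the value to pass to `EvalCallback` so that an evaluation happens roughly every `target_timesteps` timesteps:

```python
def eval_freq_for(target_timesteps: int, n_envs: int) -> int:
    """Return the eval_freq (counted in calls to the vectorized env's step())
    that gives roughly one evaluation every `target_timesteps` timesteps."""
    # Each call to step() on a vectorized env advances all n_envs
    # environments at once, so it accounts for n_envs timesteps.
    return max(target_timesteps // n_envs, 1)


def timesteps_between_evals(eval_freq: int, n_envs: int) -> int:
    """Number of timesteps collected between two evaluations."""
    return eval_freq * n_envs


for n_envs in (1, 4, 8):
    freq = eval_freq_for(500, n_envs)
    print(f"n_envs={n_envs} eval_freq={freq} timesteps={timesteps_between_evals(freq, n_envs)}")
```

Because of the integer division, the effective interval can be slightly shorter than the target when `n_envs` does not divide it evenly.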
## Atari Games
:::{figure} ../_static/img/breakout.gif
Trained A2C agent on Breakout
:::
:::{figure} https://cdn-images-1.medium.com/max/960/1*UHYJE7lF8IDZS_U5SsAFUQ.gif
Pong Environment
:::
Training an RL agent on Atari games is straightforward thanks to the `make_atari_env` helper function.
It will do [all the preprocessing](https://danieltakeshi.github.io/2016/11/25/frame-skipping-and-preprocessing-for-deep-q-networks-on-atari-2600-games/)
and multiprocessing for you. To install the Atari environments and ROMs, run `pip install gymnasium[atari,accept-rom-license]`, or install Stable Baselines3 with `pip install stable-baselines3[extra]` to get these and other optional dependencies.
```{image} ../_static/img/colab-badge.svg
:target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/atari_games.ipynb
```
:::{note}
When working with Atari environments, be aware that the default `terminal_on_life_loss=True` behavior
can cause `env.reset()` to perform a no-op step instead of truly resetting the environment when
the episode ends due to a life loss (not game over, see [issue #666](https://github.com/DLR-RM/stable-baselines3/issues/666)).
To ensure `reset()` always resets the environment, use:
```python
from stable_baselines3.common.env_util import make_atari_env
import ale_py
env = make_atari_env(
"BreakoutNoFrameskip-v4",
n_envs=1,
wrapper_kwargs=dict(terminal_on_life_loss=False)
)
```
:::
```python
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack
from stable_baselines3 import A2C
import ale_py
# There already exists an environment generator
# that will make and wrap atari environments correctly.
# Here we also use multi-worker training (n_envs=4 => 4 environments)
vec_env = make_atari_env("PongNoFrameskip-v4", n_envs=4, seed=0)
# Frame-stacking with 4 frames
vec_env = VecFrameStack(vec_env, n_stack=4)
model = A2C("CnnPolicy", vec_env, verbose=1)
model.learn(total_timesteps=25_000)
obs = vec_env.reset()
while True:
    action, _states = model.predict(obs, deterministic=False)
    obs, rewards, dones, info = vec_env.step(action)
    vec_env.render("human")
```
## PyBullet: Normalizing input features
Normalizing input features may be essential to successful training of an RL agent
(by default, images are scaled, but other types of input are not),
for instance when training on [PyBullet](https://github.com/bulletphysics/bullet3/) environments.
For this, there is a wrapper `VecNormalize` that will compute a running average and standard deviation of the input features (it can do the same for rewards).
:::{note}
You need to install the PyBullet envs with `pip install pybullet_envs_gymnasium`.
:::
```{image} ../_static/img/colab-badge.svg
:target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/pybullet.ipynb
```
```python
from pathlib import Path
import pybullet_envs_gymnasium
from stable_baselines3.common.vec_env import VecNormalize
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3 import PPO
# Alternatively, you can use the MuJoCo equivalent "HalfCheetah-v4"
vec_env = make_vec_env("HalfCheetahBulletEnv-v0", n_envs=1)
# Automatically normalize the input features and reward
vec_env = VecNormalize(vec_env, norm_obs=True, norm_reward=True, clip_obs=10.0)
model = PPO("MlpPolicy", vec_env)
model.learn(total_timesteps=2000)
# Don't forget to save the VecNormalize statistics when saving the agent
log_dir = Path("/tmp/")
model.save(log_dir / "ppo_halfcheetah")
stats_path = log_dir / "vec_normalize.pkl"
vec_env.save(stats_path)
# To demonstrate loading
del model, vec_env
# Load the saved statistics
vec_env = make_vec_env("HalfCheetahBulletEnv-v0", n_envs=1)
vec_env = VecNormalize.load(stats_path, vec_env)
# do not update them at test time
vec_env.training = False
# reward normalization is not needed at test time
vec_env.norm_reward = False
# Load the agent
model = PPO.load(log_dir / "ppo_halfcheetah", env=vec_env)
```
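To illustrate what "a running average and standard deviation of the input features" means in practice, here is a simplified NumPy sketch. It is not the SB3 implementation: the class name `RunningStats` and its defaults are assumptions for this example, and it uses the standard formula for combining the mean and variance of two batches.

```python
import numpy as np


class RunningStats:
    """Track the running mean and variance of a stream of observation batches."""

    def __init__(self, shape: tuple, epsilon: float = 1e-4):
        self.mean = np.zeros(shape, dtype=np.float64)
        self.var = np.ones(shape, dtype=np.float64)
        # Small prior count, avoids a division by zero before the first update
        self.count = epsilon

    def update(self, batch: np.ndarray) -> None:
        batch_mean = batch.mean(axis=0)
        batch_var = batch.var(axis=0)
        batch_count = batch.shape[0]
        delta = batch_mean - self.mean
        total = self.count + batch_count
        # Combine the statistics of the old data and of the new batch
        new_mean = self.mean + delta * batch_count / total
        m2 = self.var * self.count + batch_var * batch_count + delta**2 * self.count * batch_count / total
        self.mean, self.var, self.count = new_mean, m2 / total, total

    def normalize(self, obs: np.ndarray, clip: float = 10.0, eps: float = 1e-8) -> np.ndarray:
        # Standardize, then clip (this plays the role of `clip_obs`)
        return np.clip((obs - self.mean) / np.sqrt(self.var + eps), -clip, clip)


rng = np.random.default_rng(0)
data = rng.normal(5.0, 3.0, size=(1000, 2))
stats = RunningStats(shape=(2,))
for chunk in np.split(data, 10):
    stats.update(chunk)
print(stats.mean, np.sqrt(stats.var))
```

Freezing the statistics at test time (`vec_env.training = False` in the example above) corresponds to calling only `normalize()` and never `update()`.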
## Hindsight Experience Replay (HER)
For this example, we are using [Highway-Env](https://github.com/eleurent/highway-env) by [@eleurent](https://github.com/eleurent).
```{image} ../_static/img/colab-badge.svg
:target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/stable_baselines_her.ipynb
```
:::{figure} https://raw.githubusercontent.com/eleurent/highway-env/gh-media/docs/media/parking-env.gif
The highway-parking-v0 environment.
:::
The parking env is a goal-conditioned continuous control task, in which the vehicle must park in a given space with the appropriate heading.
:::{note}
The hyperparameters in the following example were optimized for that environment.
:::
```python
import gymnasium as gym
import highway_env
import numpy as np
from stable_baselines3 import HerReplayBuffer, SAC, DDPG, TD3
from stable_baselines3.common.noise import NormalActionNoise
env = gym.make("parking-v0")
# Create 4 artificial transitions per real transition
n_sampled_goal = 4
# SAC hyperparams:
model = SAC(
"MultiInputPolicy",
env,
replay_buffer_class=HerReplayBuffer,
replay_buffer_kwargs=dict(
n_sampled_goal=n_sampled_goal,
goal_selection_strategy="future",
),
verbose=1,
buffer_size=int(1e6),
learning_rate=1e-3,
gamma=0.95,
batch_size=256,
policy_kwargs=dict(net_arch=[256, 256, 256]),
)
model.learn(int(2e5))
model.save("her_sac_highway")
# Load saved model
# Because it needs access to `env.compute_reward()`
# HER must be loaded with the env
env = gym.make("parking-v0", render_mode="human") # Change the render mode
model = SAC.load("her_sac_highway", env=env)
obs, info = env.reset()
# Evaluate the agent
episode_reward = 0
for _ in range(100):
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = env.step(action)
    episode_reward += reward
    if terminated or truncated or info.get("is_success", False):
        print("Reward:", episode_reward, "Success?", info.get("is_success", False))
        episode_reward = 0.0
        obs, info = env.reset()
```
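To make `n_sampled_goal` and the `"future"` strategy more concrete, here is a toy sketch on a one-dimensional episode. It only illustrates the idea (relabel a transition with a goal that was achieved later in the same episode) and is not how `HerReplayBuffer` is implemented; the function names, the sparse reward and the tolerance are assumptions made for this example.

```python
import numpy as np


def relabel_future(next_achieved: np.ndarray, t: int, n_sampled_goal: int, rng) -> np.ndarray:
    """Pick `n_sampled_goal` substitute goals for transition `t` among the
    goals achieved from step `t` onwards in the same episode."""
    future_idx = rng.integers(t, len(next_achieved), size=n_sampled_goal)
    return next_achieved[future_idx]


def sparse_reward(achieved, desired, tol: float = 0.05):
    """Toy reward: 0 when the goal is reached (within `tol`), -1 otherwise."""
    return -(np.abs(achieved - desired) > tol).astype(np.float32)


rng = np.random.default_rng(0)
# Positions reached after each of the 5 steps of a toy episode
next_achieved = np.array([0.1, 0.2, 0.35, 0.5, 0.6])
original_goal = 1.0  # never reached: every real transition gets a reward of -1
t = 2

virtual_goals = relabel_future(next_achieved, t, n_sampled_goal=4, rng=rng)
virtual_rewards = sparse_reward(next_achieved[t], virtual_goals)
print("real reward:", sparse_reward(next_achieved[t], original_goal))
print("virtual goals:", virtual_goals)
print("virtual rewards:", virtual_rewards)
```

The relabeled transitions can carry a non-negative reward even though the original goal was never reached, which is what gives the agent a learning signal in sparse-reward tasks.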
## Learning Rate Schedule
All algorithms allow you to pass a learning rate schedule that takes as input the current progress remaining (from 1 to 0).
`PPO`'s `clip_range` parameter also accepts such a schedule.
The [RL Zoo](https://github.com/DLR-RM/rl-baselines3-zoo) already includes
linear and constant schedules.
```python
from typing import Callable
from stable_baselines3 import PPO
def linear_schedule(initial_value: float) -> Callable[[float], float]:
    """
    Linear learning rate schedule.

    :param initial_value: Initial learning rate.
    :return: schedule that computes
      current learning rate depending on remaining progress
    """
    def func(progress_remaining: float) -> float:
        """
        Progress will decrease from 1 (beginning) to 0.

        :param progress_remaining:
        :return: current learning rate
        """
        return progress_remaining * initial_value

    return func
# Initial learning rate of 0.001
model = PPO("MlpPolicy", "CartPole-v1", learning_rate=linear_schedule(0.001), verbose=1)
model.learn(total_timesteps=20_000)
# By default, `reset_num_timesteps` is True, in which case the learning rate schedule resets.
# progress_remaining = 1.0 - (num_timesteps / total_timesteps)
model.learn(total_timesteps=10_000, reset_num_timesteps=True)
```
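Any callable that maps the remaining progress to a value can serve as a schedule. As a further sketch (the names `linear_schedule_with_floor` and `compute_progress_remaining` are hypothetical helpers, not SB3 functions), the following variant decays linearly towards a non-zero final value and shows how the remaining progress is derived from the timestep counters, following the formula in the comment above:

```python
from typing import Callable


def compute_progress_remaining(num_timesteps: int, total_timesteps: int) -> float:
    """Remaining progress, going from 1 (start of training) to 0 (end)."""
    return 1.0 - num_timesteps / total_timesteps


def linear_schedule_with_floor(initial_value: float, final_value: float) -> Callable[[float], float]:
    """Linear decay from `initial_value` (progress 1) to `final_value` (progress 0)."""

    def func(progress_remaining: float) -> float:
        return final_value + progress_remaining * (initial_value - final_value)

    return func


schedule = linear_schedule_with_floor(1e-3, 1e-5)
for step in (0, 10_000, 20_000):
    progress = compute_progress_remaining(step, total_timesteps=20_000)
    print(f"step={step} progress_remaining={progress:.2f} lr={schedule(progress):.6f}")
```

The returned function can be passed as `learning_rate` in the same way as `linear_schedule(0.001)` above.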
## Advanced Saving and Loading
In this example, we show how to use a policy independently from a model (and how to save it, load it) and save/load a replay buffer.
By default, the replay buffer is not saved when calling `model.save()`, in order to save space on the disk (a replay buffer can be up to several GB when using images).
However, SB3 provides `save_replay_buffer()` and `load_replay_buffer()` methods to save and load it separately.
:::{note}
To continue training a model after loading it, we recommend also loading the replay buffer to ensure stable learning (for off-policy algorithms).
You also need to pass `reset_num_timesteps=True` to the `learn` function, which initializes the environment
and agent for training, if a new environment was created since saving the model.
:::
```{image} ../_static/img/colab-badge.svg
:target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/advanced_saving_loading.ipynb
```
```python
from stable_baselines3 import SAC
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.sac.policies import MlpPolicy
# Create the model and the training environment
model = SAC("MlpPolicy", "Pendulum-v1", verbose=1,
learning_rate=1e-3)
# train the model
model.learn(total_timesteps=6000)
# save the model
model.save("sac_pendulum")
# the saved model does not contain the replay buffer
loaded_model = SAC.load("sac_pendulum")
print(f"The loaded_model has {loaded_model.replay_buffer.size()} transitions in its buffer")
# now save the replay buffer too
model.save_replay_buffer("sac_replay_buffer")
# load it into the loaded_model
loaded_model.load_replay_buffer("sac_replay_buffer")
# now the loaded replay is not empty anymore
print(f"The loaded_model has {loaded_model.replay_buffer.size()} transitions in its buffer")
# Save the policy independently from the model
# Note: if you don't save the complete model with `model.save()`
# you cannot continue training afterward
policy = model.policy
policy.save("sac_policy_pendulum")
# Retrieve the environment
env = model.get_env()
# Evaluate the policy
mean_reward, std_reward = evaluate_policy(policy, env, n_eval_episodes=10, deterministic=True)
print(f"mean_reward={mean_reward:.2f} +/- {std_reward}")
# Load the policy independently from the model
saved_policy = MlpPolicy.load("sac_policy_pendulum")
# Evaluate the loaded policy
mean_reward, std_reward = evaluate_policy(saved_policy, env, n_eval_episodes=10, deterministic=True)
print(f"mean_reward={mean_reward:.2f} +/- {std_reward}")
```
## Accessing and modifying model parameters
You can access a model's parameters via the `set_parameters` and `get_parameters` functions,
or via `model.policy.state_dict()` (and `load_state_dict()`),
which use dictionaries that map variable names to PyTorch tensors.
These functions are useful when you need to, e.g., evaluate a large set of models with the same network structure,
visualize different layers of the network, or modify parameters manually.
Policies also offer a simple way to save/load weights as a NumPy vector, using the `parameters_to_vector()`
and `load_from_vector()` methods.
The following example demonstrates reading parameters, modifying some of them, and loading them into the model
by implementing an [evolution strategy (ES)](http://blog.otoro.net/2017/10/29/visual-evolution-strategies/)
for solving the `CartPole-v1` environment. The initial guess for the parameters is obtained by running
A2C policy gradient updates on the model.
```python
from typing import Dict
import gymnasium as gym
import numpy as np
import torch as th
from stable_baselines3 import A2C
from stable_baselines3.common.evaluation import evaluate_policy
def mutate(params: Dict[str, th.Tensor]) -> Dict[str, th.Tensor]:
    """Mutate parameters by adding normal noise to them"""
    return dict((name, param + th.randn_like(param)) for name, param in params.items())
# Create policy with a small network
model = A2C(
"MlpPolicy",
"CartPole-v1",
ent_coef=0.0,
policy_kwargs={"net_arch": [32]},
seed=0,
learning_rate=0.05,
)
# Use traditional actor-critic policy gradient updates to
# find good initial parameters
model.learn(total_timesteps=10_000)
# Include only variables with "policy", "action" (policy) or "shared_net" (shared layers)
# in their name: only these ones affect the action.
# NOTE: you can retrieve those parameters using model.get_parameters() too
mean_params = dict(
(key, value)
for key, value in model.policy.state_dict().items()
if ("policy" in key or "shared_net" in key or "action" in key)
)
# population size of 50 individuals
pop_size = 50
# Keep top 10%
n_elite = pop_size // 10
# Retrieve the environment
vec_env = model.get_env()
for iteration in range(10):
    # Create population of candidates and evaluate them
    population = []
    for population_i in range(pop_size):
        candidate = mutate(mean_params)
        # Load new policy parameters to agent.
        # Tell function that it should only update parameters
        # we give it (policy parameters)
        model.policy.load_state_dict(candidate, strict=False)
        # Evaluate the candidate
        fitness, _ = evaluate_policy(model, vec_env)
        population.append((candidate, fitness))
    # Take top 10% and use average over their parameters as next mean parameter
    top_candidates = sorted(population, key=lambda x: x[1], reverse=True)[:n_elite]
    mean_params = dict(
        (
            name,
            th.stack([candidate[0][name] for candidate in top_candidates]).mean(dim=0),
        )
        for name in mean_params.keys()
    )
    mean_fitness = sum(top_candidate[1] for top_candidate in top_candidates) / n_elite
    print(f"Iteration {iteration + 1:<3} Mean top fitness: {mean_fitness:.2f}")
    print(f"Best fitness: {top_candidates[0][1]:.2f}")
```
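The selection step used above (mutate, evaluate, keep the top 10%, average them) can be studied in isolation. The following NumPy sketch applies the same loop to a toy fitness function instead of an RL policy; the target vector, the `sigma` noise scale and the function names are assumptions chosen for illustration.

```python
import numpy as np


def es_step(mean_params, fitness_fn, rng, pop_size=50, n_elite=5, sigma=1.0):
    """One iteration: sample a population around the mean, keep the elite, average it."""
    candidates = mean_params + sigma * rng.standard_normal((pop_size, mean_params.size))
    fitness = np.array([fitness_fn(candidate) for candidate in candidates])
    # Indices of the n_elite candidates with the highest fitness
    elite_idx = np.argsort(fitness)[::-1][:n_elite]
    return candidates[elite_idx].mean(axis=0), fitness[elite_idx].mean()


# Toy problem: fitness is highest (zero) when params == target
target = np.array([1.0, -2.0, 0.5])


def fitness_fn(params: np.ndarray) -> float:
    return -float(np.sum((params - target) ** 2))


rng = np.random.default_rng(0)
mean_params = np.zeros(3)
for iteration in range(20):
    mean_params, mean_fitness = es_step(mean_params, fitness_fn, rng, sigma=0.3)
print("final parameters:", mean_params, "fitness:", fitness_fn(mean_params))
```

In the SB3 example, `fitness_fn` corresponds to `evaluate_policy` and the parameter vector to the filtered `state_dict` entries.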
## SB3 with Isaac Lab, Brax, Procgen, EnvPool
Some massively parallel simulations such as [EnvPool](https://github.com/sail-sg/envpool), [Isaac Lab](https://github.com/isaac-sim/IsaacLab), [Brax](https://github.com/google/brax) or [ProcGen](https://github.com/Farama-Foundation/Procgen2) already produce a vectorized environment to speed up data collection (see discussion in [issue #314](https://github.com/DLR-RM/stable-baselines3/issues/314)).
To use SB3 with these tools, you need to wrap the env with a tool-specific `VecEnvWrapper` that pre-processes the data for SB3.
You can find links to some of these wrappers in [issue #772](https://github.com/DLR-RM/stable-baselines3/issues/772#issuecomment-1048657002).
- Isaac Lab wrapper: [link](https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/utils/wrappers/sb3.py)
- Brax: [link](https://gist.github.com/araffin/a7a576ec1453e74d9bb93120918ef7e7)
- EnvPool: [link](https://github.com/sail-sg/envpool/blob/main/examples/sb3_examples/ppo.py)
- Getting SAC to Work on a Massive Parallel Simulator: <https://araffin.github.io/post/sac-massive-sim/>
## SB3 with DeepMind Control (dm_control)
If you want to use SB3 with [dm_control](https://github.com/google-deepmind/dm_control), you need to use two wrappers (one from [shimmy](https://github.com/Farama-Foundation/Shimmy), and the pre-built `FlattenObservation` wrapper from Gymnasium) to convert it to a Gymnasium-compatible environment:
```python
import shimmy
import stable_baselines3 as sb3
from dm_control import suite
from gymnasium.wrappers import FlattenObservation
# Available envs:
# suite._DOMAINS and suite.dog.SUITE
env = suite.load(domain_name="dog", task_name="run")
gym_env = FlattenObservation(shimmy.DmControlCompatibilityV0(env))
model = sb3.PPO("MlpPolicy", gym_env, verbose=1)
model.learn(10_000, progress_bar=True)
```
## Record a Video
Record an MP4 video (here using a random agent).
:::{note}
It requires `ffmpeg` or `avconv` to be installed on the machine.
:::
```python
import gymnasium as gym
from stable_baselines3.common.vec_env import VecVideoRecorder, DummyVecEnv
env_id = "CartPole-v1"
video_folder = "logs/videos/"
video_length = 100
vec_env = DummyVecEnv([lambda: gym.make(env_id, render_mode="rgb_array")])
obs = vec_env.reset()
# Record the video starting at the first step
vec_env = VecVideoRecorder(vec_env, video_folder,
record_video_trigger=lambda x: x == 0, video_length=video_length,
name_prefix=f"random-agent-{env_id}")
vec_env.reset()
for _ in range(video_length + 1):
    action = [vec_env.action_space.sample()]
    obs, _, _, _ = vec_env.step(action)
# Save the video
vec_env.close()
```
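The `record_video_trigger` argument used above is a function of the current step count that returns whether a recording should start. As a small sketch (the helper `every_n_steps` is hypothetical, not part of SB3), a periodic trigger could look like this:

```python
from typing import Callable


def every_n_steps(n: int) -> Callable[[int], bool]:
    """Return a trigger that fires at steps 0, n, 2n, ..."""

    def trigger(step: int) -> bool:
        return step % n == 0

    return trigger


trigger = every_n_steps(10_000)
print([step for step in (0, 1, 9_999, 10_000, 20_000) if trigger(step)])
```

It would be passed as `record_video_trigger=every_n_steps(10_000)`; choosing a `video_length` smaller than `n` keeps the recorded clips from running into each other.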
## Bonus: Make a GIF of a Trained Agent
```python
import imageio
import numpy as np
from stable_baselines3 import A2C
model = A2C("MlpPolicy", "LunarLander-v3").learn(100_000)
images = []
obs = model.env.reset()
img = model.env.render(mode="rgb_array")
for i in range(350):
    images.append(img)
    action, _ = model.predict(obs)
    obs, _, _, _ = model.env.step(action)
    img = model.env.render(mode="rgb_array")
imageio.mimsave("lander_a2c.gif", [np.array(img) for i, img in enumerate(images) if i % 2 == 0], fps=29)
```
[advanced saving and loading]: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/advanced_saving_loading.ipynb
[atari games]: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/atari_games.ipynb
[hindsight experience replay]: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/stable_baselines_her.ipynb
[monitor training and plotting]: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/monitor_training.ipynb
[multiprocessing]: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/multiprocessing_rl.ipynb
[pybullet]: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/pybullet.ipynb
[rl baselines zoo]: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/rl-baselines-zoo.ipynb
================================================
FILE: docs/guide/export.md
================================================
(export)=
# Exporting models
After training an agent, you may want to deploy/use it in another language
or framework, like [tensorflowjs](https://github.com/tensorflow/tfjs).
Stable Baselines3 does not include tools to export models to other frameworks, but
this document aims to cover parts that are required for exporting along with
more detailed stories from users of Stable Baselines3.
## Background
In Stable Baselines3, the controller is stored inside policies which convert
observations into actions. Each learning algorithm (e.g. DQN, A2C, SAC)
contains a policy object which represents the currently learned behavior,
accessible via `model.policy`.
Policies hold enough information to do the inference (i.e. predict actions),
so it is enough to export these policies (cf {ref}`examples <examples>`)
to do inference in another framework.
:::{warning}
When using CNN policies, the observation is normalized during pre-processing.
This pre-processing is done *inside* the policy (dividing by 255 to have values in [0, 1]).
:::
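As a rough illustration of what this means for the consumer of an exported image policy, the NumPy sketch below separates the two operations involved. The helper names are assumptions for this example. If the exported graph wraps the whole policy, the division by 255 already happens inside it, so only the channel-first layout has to be handled by the caller; if you rebuild the network manually, the scaling has to be replicated as well.

```python
import numpy as np


def to_channel_first(image_hwc: np.ndarray) -> np.ndarray:
    """(H, W, C) image -> (C, H, W), the layout expected by PyTorch CNNs."""
    return np.transpose(image_hwc, (2, 0, 1))


def scale_image(image: np.ndarray) -> np.ndarray:
    """What the policy does internally: uint8 pixels -> float32 values in [0, 1]."""
    return image.astype(np.float32) / 255.0


# A dummy stack of 4 grayscale 84x84 frames, channel-last
frame = np.full((84, 84, 4), 255, dtype=np.uint8)
batch = to_channel_first(frame)[None]  # add the batch dimension
print(batch.shape, batch.dtype)
print(scale_image(batch).max())
```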
## Export to ONNX
If you are using PyTorch 2.0+ and ONNX Opset 14+, you can easily export SB3 policies using the following code:
:::{warning}
The following returns normalized actions and doesn't include the [post-processing](https://github.com/DLR-RM/stable-baselines3/blob/a9273f968eaf8c6e04302a07d803eebfca6e7e86/stable_baselines3/common/policies.py#L370-L377) step that is done with continuous actions (clip or unscale the action to the correct space).
:::
```python
import torch as th
from typing import Tuple
from stable_baselines3 import PPO
from stable_baselines3.common.policies import BasePolicy
class OnnxableSB3Policy(th.nn.Module):
    def __init__(self, policy: BasePolicy):
        super().__init__()
        self.policy = policy

    def forward(self, observation: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
        # NOTE: Preprocessing is included, but postprocessing
        # (clipping/unscaling actions) is not.
        # If needed, you also need to transpose the images so that they are channel first
        # Use deterministic=False if you want to export the stochastic policy
        # policy() returns `actions, values, log_prob` for PPO
        return self.policy(observation, deterministic=True)
# Example: model = PPO("MlpPolicy", "Pendulum-v1")
PPO("MlpPolicy", "Pendulum-v1").save("PathToTrainedModel")
model = PPO.load("PathToTrainedModel.zip", device="cpu")
onnx_policy = OnnxableSB3Policy(model.policy)
observation_size = model.observation_space.shape
dummy_input = th.randn(1, *observation_size)
th.onnx.export(
onnx_policy,
dummy_input,
"my_ppo_model.onnx",
opset_version=17,
input_names=["input"],
)
##### Load and test with onnx
import onnx
import onnxruntime as ort
import numpy as np
onnx_path = "my_ppo_model.onnx"
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
observation = np.zeros((1, *observation_size)).astype(np.float32)
ort_sess = ort.InferenceSession(onnx_path)
actions, values, log_prob = ort_sess.run(None, {"input": observation})
print(actions, values, log_prob)
# Check that the predictions are the same
with th.no_grad():
    print(model.policy(th.as_tensor(observation), deterministic=True))
```
For exporting `MultiInputPolicy`, please have a look at [GH#1873](https://github.com/DLR-RM/stable-baselines3/issues/1873#issuecomment-2710776085).
For SAC the procedure is similar. The example shown only exports the actor network as the actor is sufficient to roll out the trained policies.
```python
import torch as th
from stable_baselines3 import SAC
class OnnxablePolicy(th.nn.Module):
    def __init__(self, actor: th.nn.Module):
        super().__init__()
        self.actor = actor

    def forward(self, observation: th.Tensor) -> th.Tensor:
        # NOTE: You may have to postprocess (unnormalize) actions
        # to the correct bounds (see commented code below)
        return self.actor(observation, deterministic=True)
# Example: model = SAC("MlpPolicy", "Pendulum-v1")
SAC("MlpPolicy", "Pendulum-v1").save("PathToTrainedModel.zip")
model = SAC.load("PathToTrainedModel.zip", device="cpu")
onnxable_model = OnnxablePolicy(model.policy.actor)
observation_size = model.observation_space.shape
dummy_input = th.randn(1, *observation_size)
th.onnx.export(
onnxable_model,
dummy_input,
"my_sac_actor.onnx",
opset_version=17,
input_names=["input"],
)
##### Load and test with onnx
import onnxruntime as ort
import numpy as np
onnx_path = "my_sac_actor.onnx"
observation = np.zeros((1, *observation_size)).astype(np.float32)
ort_sess = ort.InferenceSession(onnx_path)
scaled_action = ort_sess.run(None, {"input": observation})[0]
print(scaled_action)
# Post-process: rescale to correct space
# Rescale the action from [-1, 1] to [low, high]
# low, high = model.action_space.low, model.action_space.high
# post_processed_action = low + (0.5 * (scaled_action + 1.0) * (high - low))
# Check that the predictions are the same
with th.no_grad():
    print(model.actor(th.as_tensor(observation), deterministic=True))
```
For more discussion around the topic, please refer to [GH#383](https://github.com/DLR-RM/stable-baselines3/issues/383) and [GH#1349](https://github.com/DLR-RM/stable-baselines3/issues/1349).
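Both export examples above return actions before post-processing. The NumPy sketch below shows the two variants mentioned in the warning: rescaling a squashed action from [-1, 1] to the action bounds (the formula from the commented code in the SAC example), and clipping an unbounded action. The function names are chosen for this example, and the bounds of ±2.0 are those of `Pendulum-v1`.

```python
import numpy as np


def unscale_action(scaled_action: np.ndarray, low: np.ndarray, high: np.ndarray) -> np.ndarray:
    """Rescale an action from [-1, 1] to [low, high] (squashed policies, e.g. SAC)."""
    return low + 0.5 * (scaled_action + 1.0) * (high - low)


def clip_action(action: np.ndarray, low: np.ndarray, high: np.ndarray) -> np.ndarray:
    """Clip an action to [low, high] (policies without squashing, e.g. PPO by default)."""
    return np.clip(action, low, high)


low = np.array([-2.0], dtype=np.float32)
high = np.array([2.0], dtype=np.float32)

sac_output = np.array([[-1.0], [0.0], [1.0]], dtype=np.float32)
print(unscale_action(sac_output, low, high))

ppo_output = np.array([[3.5], [-0.5]], dtype=np.float32)
print(clip_action(ppo_output, low, high))
```

In a deployment setting, `low` and `high` would be stored alongside the exported model, since the action space is not part of the ONNX graph.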
## Trace/Export to C++
You can use PyTorch JIT to trace and save a trained model that can be reused in other applications
(for instance inference code written in C++).
There is a draft PR in the RL Zoo about C++ export: <https://github.com/DLR-RM/rl-baselines3-zoo/pull/228>
```python
# See "ONNX export" for imports and OnnxablePolicy
jit_path = "sac_traced.pt"
# Trace and optimize the module
traced_module = th.jit.trace(onnxable_model.eval(), dummy_input)
frozen_module = th.jit.freeze(traced_module)
frozen_module = th.jit.optimize_for_inference(frozen_module)
th.jit.save(frozen_module, jit_path)
##### Load and test with torch
import torch as th
dummy_input = th.randn(1, *observation_size)
loaded_module = th.jit.load(jit_path)
action_jit = loaded_module(dummy_input)
```
## Export to ONNX-JS / ONNX Runtime Web
Official documentation: <https://onnxruntime.ai/docs/tutorials/web/build-web-app.html>
Full example code: <https://github.com/JonathanColetti/CarDodgingGym>
Demo: <https://jonathancoletti.github.io/CarDodgingGym>
The code linked above is a complete example (using car dodging environment) that:
1. Creates/Trains a PPO model
2. Exports the model to ONNX along with normalization stats in JSON
3. Runs in the browser with normalization using onnxruntime-web to achieve similar results
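Step 2 above mentions exporting normalization statistics as JSON. The layout used by the linked project is not reproduced here; the following is only a sketch with a hypothetical JSON structure and made-up statistics. With `VecNormalize`, the values would typically be read from `vec_env.obs_rms.mean`, `vec_env.obs_rms.var`, `vec_env.clip_obs` and `vec_env.epsilon`.

```python
import json

import numpy as np


def export_norm_stats(mean, var, clip_obs: float = 10.0, epsilon: float = 1e-8) -> str:
    """Serialize observation normalization statistics to a JSON string."""
    return json.dumps(
        {
            "mean": np.asarray(mean).tolist(),
            "var": np.asarray(var).tolist(),
            "clip_obs": float(clip_obs),
            "epsilon": float(epsilon),
        }
    )


def normalize_obs(obs: np.ndarray, stats: dict) -> np.ndarray:
    """Apply the same normalization at inference time, outside of SB3."""
    mean = np.asarray(stats["mean"], dtype=np.float32)
    var = np.asarray(stats["var"], dtype=np.float32)
    normalized = (obs - mean) / np.sqrt(var + stats["epsilon"])
    return np.clip(normalized, -stats["clip_obs"], stats["clip_obs"]).astype(np.float32)


# Made-up statistics for a 3-dimensional observation
stats_json = export_norm_stats(mean=[0.5, -0.1, 0.0], var=[0.25, 1.0, 4.0])
stats = json.loads(stats_json)
obs = np.array([[1.0, -0.1, 100.0]], dtype=np.float32)
print(normalize_obs(obs, stats))
```

The same arithmetic then has to be ported to JavaScript so that the browser feeds the ONNX model observations on the scale it was trained with.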
Below is a simple example that converts a model to ONNX and then runs inference, without post-processing, using ONNX Runtime Web:
```python
import torch as th
from stable_baselines3 import SAC
class OnnxablePolicy(th.nn.Module):
    def __init__(self, actor: th.nn.Module):
        super().__init__()
        self.actor = actor

    def forward(self, observation: th.Tensor) -> th.Tensor:
        # NOTE: You may have to postprocess (unnormalize or renormalize)
        return self.actor(observation, deterministic=True)
# Example: model = SAC("MlpPolicy", "Pendulum-v1")
SAC("MlpPolicy", "Pendulum-v1").save("PathToTrainedModel.zip")
model = SAC.load("PathToTrainedModel.zip", device="cpu")
onnxable_model = OnnxablePolicy(model.policy.actor)
observation_size = model.observation_space.shape
dummy_input = th.randn(1, *observation_size)
th.onnx.export(
onnxable_model,
dummy_input,
"my_sac_actor.onnx",
opset_version=17,
input_names=["input"],
)
```
```javascript
// Install using `npm install onnxruntime-web` (tested with version 1.19) or using cdn
import * as ort from 'onnxruntime-web';
async function runInference() {
  const session = await ort.InferenceSession.create('my_sac_actor.onnx');
  // The observation_size = 3 (for Pendulum-v1)
  const inputData = Float32Array.from([0.1, -0.2, 0.3]);
  const inputTensor = new ort.Tensor('float32', inputData, [1, 3]);
  const results = await session.run({ input: inputTensor });
  const outputName = session.outputNames[0];
  const action = results[outputName].data;
  console.log('Predicted action=', action);
}
runInference();
```
## Export to TensorFlow.js
:::{warning}
As of November 2025, [onnx2tf](https://github.com/PINTO0309/onnx2tf) does not support TensorFlow.js. Therefore, [tfjs-converter](https://github.com/tensorflow/tfjs-converter) is used instead. However, tfjs-converter is not currently maintained and requires older opsets and TensorFlow versions.
:::
In order for this to work, you have to do multiple conversions: SB3 => ONNX => TensorFlow => TensorFlow.js.
The opset version needs to be lowered for this conversion: `opset_version=14` is currently required. If you do not need TensorFlow.js specifically, the ONNX-based approaches above are more stable and work with a higher opset.
The following is a simple example that showcases the full conversion + inference.
Please refer to the previous sections for the first step (SB3 => ONNX), changing only `opset_version` to 14.
```python
# Tested with python3.10
# Then install these dependencies in a fresh env
"""
pip install --use-deprecated=legacy-resolver tensorflow==2.13.0 keras==2.13.1 onnx==1.16.0 onnx-tf==1.9.0 tensorflow-probability==0.21.0 tensorflowjs==4.15.0 jax==0.4.26 jaxlib==0.4.26
"""
# Then run this codeblock
# If there are no errors (the folder is structured correctly), then run
"""
tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model tf_model tfjs_model
"""
# If you get an error exporting using `tensorflowjs_converter`, then upgrade tensorflow
"""
pip install --upgrade tensorflow tensorflow-decision-forests tensorflowjs
"""
# And retry the conversion, it should now work (do not rerun this codeblock)
"""
tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model tf_model tfjs_model
"""
import onnx
import onnx_tf.backend
import tensorflow as tf
ONNX_FILE_PATH = "my_sac_actor.onnx"
MODEL_PATH = "tf_model"
onnx_model = onnx.load(ONNX_FILE_PATH)
onnx.checker.check_model(onnx_model)
print(onnx.helper.printable_graph(onnx_model.graph))
print('Converting ONNX to TF...')
tf_rep = onnx_tf.backend.prepare(onnx_model)
tf_rep.export_graph(MODEL_PATH)
# After this do not forget to use `tensorflowjs_converter`
```
```javascript
import * as tf from 'https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@4.15.0/+esm';
// Post processing not included
async function runInference() {
  const MODEL_URL = './tfjs_model/model.json';
  const model = await tf.loadGraphModel(MODEL_URL);
  // Observation_size is 3 for Pendulum-v1
  const inputData = [1.0, 0.0, 0.0];
  const inputTensor = tf.tensor2d([inputData], [1, 3]);
  const resultTensor = model.execute(inputTensor);
  const action = await resultTensor.data();
  console.log('Predicted action=', action);
  inputTensor.dispose();
  resultTensor.dispose();
}
runInference();
```
## Export to TFLite / Coral (Edge TPU)
Full example code: <https://github.com/chunky/sb3_to_coral>
Google created a chip called the "Coral" for deploying AI to the
edge. It's available in a variety of form factors, including USB (using
the Coral on a Raspberry Pi, with a SB3-developed model, was the original
motivation for the code example above).
The Coral chip is fast, with very low power consumption, but only has limited
on-device training abilities. More information is on the webpage here:
<https://coral.ai>.
To deploy to a Coral, one must work via TFLite, and quantize the
network to reflect the Coral's capabilities. The full chain to go from
SB3 to Coral is: SB3 (Torch) => ONNX => TensorFlow => TFLite => Coral.
The code linked above is a complete, minimal, example that:
1. Creates a model using SB3
2. Follows the path of exports all the way to TFLite and Google Coral
3. Demonstrates the forward pass for most exported variants
There are a number of pitfalls along the way to the complete conversion
that this example covers, including:
- Making the Gym's observation work with ONNX properly
- Quantising the TFLite model appropriately to align with Gym
while still taking advantage of Coral
- Using `OnnxablePolicy` as described in the above example
## Manual export
You can also manually export required parameters (weights) and construct the
network in your desired framework.
You can access parameters of the model via agents'
{func}`get_parameters <stable_baselines3.common.base_class.BaseAlgorithm.get_parameters>` function.
As policies are also PyTorch modules, you can also access `model.policy.state_dict()` directly.
To find the architecture of the networks for each algorithm, it is best to check the `policies.py` file located
in their respective folders.
:::{note}
In most cases, we recommend using PyTorch methods `state_dict()` and `load_state_dict()` from the policy,
unless you need to access the optimizers' state dict too. In that case, you need to call `get_parameters()`.
:::
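To give an idea of what "constructing the network in your desired framework" involves, here is a minimal NumPy forward pass for a small tanh MLP. The weights are random and the key names (`layer0.weight`, ...) are hypothetical; in practice you would inspect `model.policy.state_dict().keys()`, convert the relevant tensors to NumPy arrays and map them to the layers of your own network.

```python
import numpy as np


def mlp_forward(obs: np.ndarray, weights: dict) -> np.ndarray:
    """Forward pass of an MLP with tanh activations on all layers but the last."""
    n_layers = len(weights) // 2
    x = obs
    for i in range(n_layers):
        # Weights follow the PyTorch nn.Linear layout: (out_features, in_features)
        x = x @ weights[f"layer{i}.weight"].T + weights[f"layer{i}.bias"]
        if i < n_layers - 1:
            x = np.tanh(x)
    return x


# Random stand-in weights for a 3 -> 64 -> 64 -> 1 network
rng = np.random.default_rng(0)
sizes = [3, 64, 64, 1]
weights = {}
for i, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
    weights[f"layer{i}.weight"] = rng.standard_normal((n_out, n_in)).astype(np.float32)
    weights[f"layer{i}.bias"] = np.zeros(n_out, dtype=np.float32)

action = mlp_forward(np.zeros((1, 3), dtype=np.float32), weights)
print(action.shape)
```

Comparing the output of such a reimplementation with `model.policy` on a few observations is a cheap way to catch a wrong layer order or a missing transpose.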
## SBX (SB3 + Jax) Export
As an example of manual export, {ref}`Stable Baselines Jax (SBX) <sbx>` policies can be exported to ONNX
by using an intermediate PyTorch representation, as shown in the following example:
```python
import numpy as np
import sbx
import torch as th
class TorchPolicy(th.nn.Module):
    def __init__(self, obs_dim: int, hidden_dim: int, act_dim: int):
        super().__init__()
        self.net = th.nn.Sequential(
            th.nn.Linear(obs_dim, hidden_dim),
            th.nn.Tanh(),
            th.nn.Linear(hidden_dim, hidden_dim),
            th.nn.Tanh(),
            th.nn.Linear(hidden_dim, act_dim),
        )

    def forward(self, x: th.Tensor) -> th.Tensor:
        return self.net(x)
model = sbx.PPO("MlpPolicy", "Pendulum-v1")
# Also possible: load a trained model
# model = sbx.PPO.load("PathToTrainedModel.zip")
params = model.policy.actor_state.params["params"]
# For debug:
print("=== SBX params ===")
for key, value in params.items():
    if isinstance(value, dict):
        for name, val in value.items():
            print(f"{key}.{name}: {val.shape}", end=" ")
    else:
        print(f"{key}: {value.shape}", end=" ")
print("\n" + "=" * 20 + "\n")
obs_dim = model.observation_space.shape
act_dim = model.action_space.shape
# Number of units in the hidden layers (assume a network architecture like [64, 64])
hidden_dim = params["Dense_0"]["kernel"].shape[1]
# map params to torch state_dict keys
num_layers = len([k for k in params.keys() if k.startswith("Dense_")])
state_dict = {}
for i in range(num_layers):
    layer_name = f"Dense_{i}"
    state_dict[f"net.{i * 2}.bias"] = th.from_numpy(np.array(params[layer_name]["bias"]))
    # Flax kernels have shape (in, out), PyTorch weights (out, in), hence the transpose
    state_dict[f"net.{i * 2}.weight"] = th.from_numpy(np.array(params[layer_name]["kernel"].T))
torch_policy = TorchPolicy(obs_dim[0], hidden_dim, act_dim[0])
print("=== Torch params ===")
print(" ".join(f"{key}:{tuple(value.shape)}" for key, value in torch_policy.named_parameters()))
print("=" * 20 + "\n")
torch_policy.load_state_dict(state_dict)
torch_policy.eval()
dummy_input = th.zeros((1, *obs_dim))
# Use normal Torch export
th.onnx.export(
    torch_policy,
    (dummy_input,),
    "my_ppo_actor.onnx",
    opset_version=18,
    input_names=["input"],
    output_names=["action"],
)
##### Load and test with onnx
import onnxruntime as ort
onnx_path = "my_ppo_actor.onnx"
ort_sess = ort.InferenceSession(onnx_path)
observation = np.random.random((1, *obs_dim)).astype(np.float32)
action = ort_sess.run(None, {"input": observation})[0]
print(action)
sbx_action, _ = model.predict(observation, deterministic=True)
with th.no_grad():
    torch_action = torch_policy(th.as_tensor(observation))
# Check that the predictions are the same
assert np.allclose(sbx_action, action)
assert np.allclose(sbx_action, torch_action.numpy())
```
================================================
FILE: docs/guide/imitation.md
================================================
(imitation)=
# Imitation Learning
The [imitation](https://github.com/HumanCompatibleAI/imitation) library implements
imitation learning algorithms on top of Stable-Baselines3, including:
- Behavioral Cloning
- [DAgger](https://arxiv.org/abs/1011.0686) with synthetic examples
- [Adversarial Inverse Reinforcement Learning](https://arxiv.org/abs/1710.11248) (AIRL)
- [Generative Adversarial Imitation Learning](https://arxiv.org/abs/1606.03476) (GAIL)
- [Deep RL from Human Preferences](https://arxiv.org/abs/1706.03741) (DRLHP)
You can install imitation with `pip install imitation`. The [imitation
documentation](https://imitation.readthedocs.io/en/latest/) has more details
on how to use the library, including [a quick start guide](https://imitation.readthedocs.io/en/latest/getting-started/first-steps.html)
for the impatient.
================================================
FILE: docs/guide/install.md
================================================
(install)=
# Installation
## Prerequisites
Stable-Baselines3 requires Python 3.10+ and PyTorch >= 2.3.
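To check an environment against these requirements before installing, a small script along the following lines may help. It is only a sketch: `parse_version` and `meets_minimum` are hypothetical helpers that compare the leading numeric components of a version string, and they ignore pre-release suffixes.

```python
import sys
from importlib import metadata


def parse_version(version: str) -> tuple:
    """Keep the leading numeric components, e.g. '2.3.1+cu121' -> (2, 3, 1)."""
    numbers = []
    for part in version.split("+")[0].split("."):
        if not part.isdigit():
            break
        numbers.append(int(part))
    return tuple(numbers)


def meets_minimum(version: str, minimum: tuple) -> bool:
    """Return True if `version` is at least `minimum` (hypothetical helper)."""
    return parse_version(version) >= minimum


python_ok = sys.version_info >= (3, 10)
try:
    torch_ok = meets_minimum(metadata.version("torch"), (2, 3))
except metadata.PackageNotFoundError:
    torch_ok = False  # PyTorch is not installed in this environment

print(f"Python >= 3.10: {python_ok}, PyTorch >= 2.3: {torch_ok}")
```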
### Windows
We recommend that Windows users use [Anaconda](https://conda.io/docs/user-guide/install/windows.html) for easier installation of Python packages and required libraries. You need an environment with Python version 3.10 or above.
For a quick start you can move straight to installing Stable-Baselines3 in the next step.
:::{note}
Trying to create Atari environments may result in vague errors related to missing DLL files and modules. This is an
issue with the `atari-py` package. [See this discussion for more information](https://github.com/openai/atari-py/issues/65).
:::
### Stable Release
To install Stable Baselines3 with pip, execute:
```bash
pip install stable-baselines3[extra]
```
:::{note}
Some shells such as Zsh require quotation marks around brackets, i.e. `pip install 'stable-baselines3[extra]'`. [More information](https://stackoverflow.com/a/30539963).
:::
This includes optional dependencies like Tensorboard, OpenCV or `ale-py` to train on Atari games. If you do not need those, you can use:
```bash
pip install stable-baselines3
```
:::{note}
If you need to work with OpenCV on a machine without an X server (for instance inside a docker image),
you will need to install `opencv-python-headless`, see [issue #298](https://github.com/DLR-RM/stable-baselines3/issues/298).
:::
## Bleeding-edge version
```bash
pip install git+https://github.com/DLR-RM/stable-baselines3
```
with extras:
```bash
pip install "stable_baselines3[extra,tests,docs] @ git+https://github.com/DLR-RM/stable-baselines3"
```
## Development version
To contribute to Stable-Baselines3, install it in editable mode with support for running tests and building the documentation:
```bash
git clone https://github.com/DLR-RM/stable-baselines3 && cd stable-baselines3
pip install -e .[docs,tests,extra]
```
## Using Docker Images
If you are looking for Docker images with Stable-Baselines3 already installed,
we recommend using the images from [RL Baselines3 Zoo](https://github.com/DLR-RM/rl-baselines3-zoo).
Otherwise, the following images contain all the dependencies for stable-baselines3 but not the stable-baselines3 package itself.
They are made for development.
### Use Built Images
GPU image (requires [nvidia-docker]):
```bash
docker pull stablebaselines/stable-baselines3
```
CPU only:
```bash
docker pull stablebaselines/stable-baselines3-cpu
```
### Build the Docker Images
Build GPU image (with nvidia-docker):
```bash
make docker-gpu
```
Build CPU image:
```bash
make docker-cpu
```
Note: if you are using a proxy, you need to pass extra params during
build and do some [tweaks]:
```bash
--network=host --build-arg HTTP_PROXY=http://your.proxy.fr:8080/ --build-arg http_proxy=http://your.proxy.fr:8080/ --build-arg HTTPS_PROXY=https://your.proxy.fr:8080/ --build-arg https_proxy=https://your.proxy.fr:8080/
```
### Run the images (CPU/GPU)
Run the nvidia-docker GPU image
```bash
docker run -it --runtime=nvidia --rm --network host --ipc=host --name test --mount src="$(pwd)",target=/home/mamba/stable-baselines3,type=bind stablebaselines/stable-baselines3 bash -c 'cd /home/mamba/stable-baselines3/ && pytest tests/'
```
Or, with the shell file:
```bash
./scripts/run_docker_gpu.sh pytest tests/
```
Run the docker CPU image
```bash
docker run -it --rm --network host --ipc=host --name test --mount src="$(pwd)",target=/home/mamba/stable-baselines3,type=bind stablebaselines/stable-baselines3-cpu bash -c 'cd /home/mamba/stable-baselines3/ && pytest tests/'
```
Or, with the shell file:
```bash
./scripts/run_docker_cpu.sh pytest tests/
```
Explanation of the docker command:
- `docker run -it` creates an instance of an image (a container) and
  runs it interactively (so ctrl+c will work)
- `--rm` removes the container once it exits/stops
  (otherwise, you will have to use `docker rm`)
- `--network host` disables network isolation, which allows using
  tensorboard/visdom on the host machine
- `--ipc=host` uses the host system’s IPC namespace. The IPC (POSIX/SysV IPC) namespace provides
  separation of named shared memory segments, semaphores and message
  queues.
- `--name test` explicitly names the container `test`;
  otherwise it will be assigned a random name
- `--mount src=...` gives the container access to the local directory (`pwd`
  command), which is mapped to `/home/mamba/stable-baselines3`, so
  all the logs created in the container in this folder will be kept
- `bash -c '...'` runs a command inside the docker image, here the tests
  (`pytest tests/`)
[nvidia-docker]: https://github.com/NVIDIA/nvidia-docker
[tweaks]: https://stackoverflow.com/questions/23111631/cannot-download-docker-images-behind-a-proxy
================================================
FILE: docs/guide/integrations.md
================================================
(integrations)=
# Integrations
## Weights & Biases
Weights & Biases provides a callback for experiment tracking that lets you visualize and share results.
The full documentation is available here: <https://docs.wandb.ai/models/integrations/stable-baselines-3>
```python
import gymnasium as gym
import wandb
from wandb.integration.sb3 import WandbCallback
from stable_baselines3 import PPO
config = {
    "policy_type": "MlpPolicy",
    "total_timesteps": 25000,
    "env_id": "CartPole-v1",
}
run = wandb.init(
    project="sb3",
    config=config,
    sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
    # monitor_gym=True,  # auto-upload the videos of agents playing the game
    # save_code=True,  # optional
)
model = PPO(config["policy_type"], config["env_id"], verbose=1, tensorboard_log=f"runs/{run.id}")
model.learn(
    total_timesteps=config["total_timesteps"],
    callback=WandbCallback(
        model_save_path=f"models/{run.id}",
        verbose=2,
    ),
)
run.finish()
```
## Hugging Face 🤗
The Hugging Face Hub 🤗 is a central place where anyone can share and explore models. It allows you to host your saved models 💾.
You can see the list of stable-baselines3 saved models here: <https://huggingface.co/models?library=stable-baselines3>
Most of them are available via the RL Zoo.
Official pre-trained models are saved in the SB3 organization on the hub: <https://huggingface.co/sb3>
We wrote a tutorial on how to use 🤗 Hub and Stable-Baselines3
[here](https://colab.research.google.com/github/huggingface/huggingface_sb3/blob/main/notebooks/sb3_huggingface.ipynb).
### Installation
```bash
pip install huggingface_sb3
```
:::{note}
If you use the [RL Zoo](https://github.com/DLR-RM/rl-baselines3-zoo), pushing models to and loading models from the hub is already integrated:
```bash
# Download model and save it into the logs/ folder
# Only use TRUST_REMOTE_CODE=True with HF models that can be trusted (here the SB3 organization)
TRUST_REMOTE_CODE=True python -m rl_zoo3.load_from_hub --algo a2c --env LunarLander-v3 -orga sb3 -f logs/
# Test the agent
python -m rl_zoo3.enjoy --algo a2c --env LunarLander-v3 -f logs/
# Push model, config and hyperparameters to the hub
python -m rl_zoo3.push_to_hub --algo a2c --env LunarLander-v3 -f logs/ -orga sb3 -m "Initial commit"
```
:::
### Download a model from the Hub
You need to copy the repo-id that contains your saved model.
For instance `sb3/demo-hf-CartPole-v1`:
```python
import os
import gymnasium as gym
from huggingface_sb3 import load_from_hub
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
# Allow the use of `pickle.load()` when downloading model from the hub
# Please make sure that the organization from which you download can be trusted
os.environ["TRUST_REMOTE_CODE"] = "True"
# Retrieve the model from the hub
## repo_id = id of the model repository from the Hugging Face Hub (repo_id = {organization}/{repo_name})
## filename = name of the model zip file from the repository
checkpoint = load_from_hub(
    repo_id="sb3/demo-hf-CartPole-v1",
    filename="ppo-CartPole-v1.zip",
)
model = PPO.load(checkpoint)
# Evaluate the agent and watch it
eval_env = gym.make("CartPole-v1")
mean_reward, std_reward = evaluate_policy(
    model, eval_env, render=True, n_eval_episodes=5, deterministic=True, warn=False
)
print(f"mean_reward={mean_reward:.2f} +/- {std_reward}")
```
You need to define two parameters:
- `repo_id`: the name of the Hugging Face repo you want to download from.
- `filename`: the file you want to download.
### Upload a model to the Hub
You can easily upload your models using two different functions:
1. `package_to_hub()`: save the model, evaluate it, generate a model card and record a replay video of your agent before pushing the complete repo to the Hub.
2. `push_to_hub()`: simply push a file to the Hub.
First, you need to be logged in to Hugging Face to upload a model:
- If you're using Colab/Jupyter Notebooks:
```python
from huggingface_hub import notebook_login
notebook_login()
```
- Otherwise:
```bash
huggingface-cli login
```
Then, in this example, we train a PPO agent to play CartPole-v1 and push it to a new repo, `sb3/demo-hf-CartPole-v1`.
#### With `package_to_hub()`
```python
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from huggingface_sb3 import package_to_hub
# Create the environment
env_id = "CartPole-v1"
env = make_vec_env(env_id, n_envs=1)
# Create the evaluation environment
eval_env = make_vec_env(env_id, n_envs=1)
# Instantiate the agent
model = PPO("MlpPolicy", env, verbose=1)
# Train the agent
model.learn(total_timesteps=int(5000))
# This method saves, evaluates, generates a model card and records a replay video of your agent before pushing the repo to the hub
package_to_hub(
    model=model,
    model_name="ppo-CartPole-v1",
    model_architecture="PPO",
    env_id=env_id,
    eval_env=eval_env,
    repo_id="sb3/demo-hf-CartPole-v1",
    commit_message="Test commit",
)
```
You need to define seven parameters:
- `model`: your trained model.
- `model_name`: the name under which the model is saved (`"ppo-CartPole-v1"` in the example above).
- `model_architecture`: name of the architecture of your model (DQN, PPO, A2C, SAC…).
- `env_id`: name of the environment.
- `eval_env`: environment used to evaluate the agent.
- `repo_id`: the name of the Hugging Face repo you want to create or update. It’s \<your huggingface username>/\<the repo name>.
- `commit_message`: the commit message.
#### With `push_to_hub()`
```python
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from huggingface_sb3 import push_to_hub
# Create the environment
env_id = "CartPole-v1"
env = make_vec_env(env_id, n_envs=1)
# Instantiate the agent
model = PPO("MlpPolicy", env, verbose=1)
# Train the agent
model.learn(total_timesteps=int(5000))
# Save the model
model.save("ppo-CartPole-v1")
# Push this saved model .zip file to the hf repo
# If this repo does not exist it will be created
## repo_id = id of the model repository from the Hugging Face Hub (repo_id = {organization}/{repo_name})
## filename: the name of the file == "name" inside model.save("ppo-CartPole-v1")
push_to_hub(
    repo_id="sb3/demo-hf-CartPole-v1",
    filename="ppo-CartPole-v1.zip",
    commit_message="Added CartPole-v1 model trained with PPO",
)
```
You need to define three parameters:
- `repo_id`: the name of the Hugging Face repo you want to create or update. It’s \<your huggingface username>/\<the repo name>.
- `filename`: the file you want to push to the Hub.
- `commit_message`: the commit message.
## MLflow
If you want to use [MLflow](https://github.com/mlflow/mlflow) to track your SB3 experiments,
you can adapt the following code, which defines a custom logger output:
```python
import sys
from typing import Any, Dict, Tuple, Union
import mlflow
import numpy as np
from stable_baselines3 import SAC
from stable_baselines3.common.logger import HumanOutputFormat, KVWriter, Logger
class MLflowOutputFormat(KVWriter):
    """
    Dumps key/value pairs into MLflow's numeric format.
    """

    def write(
        self,
        key_values: Dict[str, Any],
        key_excluded: Dict[str, Union[str, Tuple[str, ...]]],
        step: int = 0,
    ) -> None:
        for (key, value), (_, excluded) in zip(
            sorted(key_values.items()), sorted(key_excluded.items())
        ):
            if excluded is not None and "mlflow" in excluded:
                continue
            if isinstance(value, np.ScalarType):
                if not isinstance(value, str):
                    mlflow.log_metric(key, value, step)


loggers = Logger(
    folder=None,
    output_formats=[HumanOutputFormat(sys.stdout), MLflowOutputFormat()],
)
with mlflow.start_run():
    model = SAC("MlpPolicy", "Pendulum-v1", verbose=2)
    # Set custom logger
    model.set_logger(loggers)
    model.learn(total_timesteps=10000, log_interval=1)
```
================================================
FILE: docs/guide/migration.md
================================================
(migration)=
# Migrating from Stable-Baselines
This is a guide to migrate from Stable-Baselines (SB2) to Stable-Baselines3 (SB3).
It also references the main changes.
## Overview
Overall Stable-Baselines3 (SB3) keeps the high-level API of Stable-Baselines (SB2).
Most of the changes are to ensure more consistency and are internal ones.
Because of the backend change, from Tensorflow to PyTorch, the internal code is much more readable and easier to debug,
at the cost of some speed (dynamic graph vs static graph, see [Issue #90](https://github.com/DLR-RM/stable-baselines3/issues/90)).
However, the algorithms were extensively benchmarked on Atari games and continuous control PyBullet envs
(see [Issue #48](https://github.com/DLR-RM/stable-baselines3/issues/48) and [Issue #49](https://github.com/DLR-RM/stable-baselines3/issues/49)),
so you should not expect a performance drop when switching from SB2 to SB3.
## How to migrate?
In most cases, replacing `from stable_baselines` with `from stable_baselines3` will be sufficient.
Some files were moved to the common folder (see below), which could result in import errors.
Some algorithms were removed because of their complexity to improve the maintainability of the project.
We recommend reading this guide carefully to understand all the changes that were made.
You can also take a look at the [rl-zoo3](https://github.com/DLR-RM/rl-baselines3-zoo) and compare the imports
to the [rl-zoo](https://github.com/araffin/rl-baselines-zoo) of SB2 to have a concrete example of successful migration.
:::{note}
If you experience massive slow-down switching to PyTorch, you may need to play with the number of threads used,
using `torch.set_num_threads(1)` or `OMP_NUM_THREADS=1`, see [issue #122](https://github.com/DLR-RM/stable-baselines3/issues/122)
and [issue #90](https://github.com/DLR-RM/stable-baselines3/issues/90).
:::
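The two options from the note above can be applied as in the following sketch. Setting both is usually unnecessary, and `1` is only a starting point to experiment with; the `try`/`except` is there so the snippet also runs where PyTorch is not installed.

```python
import os

# Must be set before PyTorch (and its OpenMP runtime) is first imported
os.environ["OMP_NUM_THREADS"] = "1"

try:
    import torch

    # Limit intra-op parallelism to a single thread
    torch.set_num_threads(1)
    num_threads = torch.get_num_threads()
except ImportError:
    num_threads = None  # PyTorch is not installed

print(f"PyTorch threads: {num_threads}")
```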
## Breaking Changes
- SB3 requires python 3.7+ (instead of python 3.5+ for SB2)
- Dropped MPI support
- Dropped layer normalized policies (`MlpLnLstmPolicy`, `CnnLnLstmPolicy`)
- LSTM policies (`MlpLstmPolicy`, `CnnLstmPolicy`) are not supported for the time being
(see [PR #53](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/53) for a recurrent PPO implementation)
- Dropped parameter noise for DDPG and DQN
- PPO is now closer to the original implementation (no clipping of the value function by default), cf PPO section below
- Orthogonal initialization is only used by A2C/PPO
- The features extractor (CNN extractor) is shared between policy and q-networks for DDPG/SAC/TD3 and only the policy loss is used to update it (much faster)
- Tensorboard legacy logging was dropped in favor of having one logger for the terminal and Tensorboard (cf {ref}`Tensorboard integration <tensorboard>`)
- We dropped ACKTR/ACER support because of their complexity compared to simpler alternatives (PPO, SAC, TD3) that perform just as well.
- We dropped GAIL support as we are focusing on model-free RL only. You can however take a look at the {ref}`imitation project <imitation>`, which implements
  GAIL and other imitation learning algorithms on top of SB3.
- `action_probability` is currently not implemented in the base class
- `pretrain()` method for behavior cloning was removed (see [issue #27](https://github.com/DLR-RM/stable-baselines3/issues/27))
You can take a look at the [issue about SB3 implementation design](https://github.com/hill-a/stable-baselines/issues/576) for more details.
### Moved Files
- `bench/monitor.py` -> `common/monitor.py`
- `logger.py` -> `common/logger.py`
- `results_plotter.py` -> `common/results_plotter.py`
- `common/cmd_util.py` -> `common/env_util.py`
Utility functions are no longer exported from `common` module, you should import them with their absolute path, e.g.:
```python
from stable_baselines3.common.env_util import make_atari_env, make_vec_env
from stable_baselines3.common.utils import set_random_seed
```
instead of `from stable_baselines3.common import make_atari_env`
### Changes and renaming in parameters
#### Base-class (all algorithms)
- `load_parameters` -> `set_parameters`
- `get/set_parameters` return a dictionary mapping object names
to their respective PyTorch tensors and other objects representing
their parameters, instead of simpler mapping of parameter name to
a NumPy array. These functions also return PyTorch tensors rather
than NumPy arrays.
#### Policies
- `cnn_extractor` -> `features_extractor`, as `features_extractor` is now used with `MlpPolicy` too
#### A2C
- `epsilon` -> `rms_prop_eps`
- `lr_schedule` is part of `learning_rate` (it can be a callable).
- `alpha`, `momentum` are modifiable through `policy_kwargs` key `optimizer_kwargs`.
:::{warning}
PyTorch implementation of RMSprop [differs from Tensorflow's](https://github.com/pytorch/pytorch/issues/23796),
which leads to [different and potentially more unstable results](https://github.com/DLR-RM/stable-baselines3/pull/110#issuecomment-663255241).
Use `stable_baselines3.common.sb2_compat.rmsprop_tf_like.RMSpropTFLike` optimizer to match the results
with TensorFlow's implementation. This can be done through `policy_kwargs`: `A2C(policy_kwargs=dict(optimizer_class=RMSpropTFLike, optimizer_kwargs=dict(eps=1e-5)))`
:::
#### PPO
- `cliprange` -> `clip_range`
- `cliprange_vf` -> `clip_range_vf`
- `nminibatches` -> `batch_size`
:::{warning}
`nminibatches` gave different batch size depending on the number of environments: `batch_size = (n_steps * n_envs) // nminibatches`
:::
- `clip_range_vf` behavior for PPO is slightly different: Set it to `None` (default) to deactivate clipping (in SB2, you had to pass `-1`, `None` meant to use `clip_range` for the clipping)
- `lam` -> `gae_lambda`
- `noptepochs` -> `n_epochs`
PPO default hyperparameters are the ones tuned for continuous control environments.
We recommend taking a look at the [RL Zoo](rl_zoo.md) for hyperparameters tuned for Atari games.
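The renames and the two behavior changes above can be summarized in a small conversion function. This is a sketch only: `convert_ppo_kwargs` is a hypothetical helper, not part of SB3, and it assumes that `n_steps` and `cliprange` are passed explicitly whenever `nminibatches` or `cliprange_vf=None` are used.

```python
def convert_ppo_kwargs(sb2_kwargs: dict, n_envs: int) -> dict:
    """Translate SB2 PPO keyword arguments to their SB3 names (hypothetical helper)."""
    renamed = {
        "cliprange": "clip_range",
        "cliprange_vf": "clip_range_vf",
        "lam": "gae_lambda",
        "noptepochs": "n_epochs",
    }
    sb3_kwargs = {renamed.get(key, key): value for key, value in sb2_kwargs.items()}

    # SB2: the batch size depended on the number of environments
    if "nminibatches" in sb3_kwargs:
        nminibatches = sb3_kwargs.pop("nminibatches")
        sb3_kwargs["batch_size"] = (sb3_kwargs["n_steps"] * n_envs) // nminibatches

    # SB2: None meant "reuse clip_range", -1 meant "no clipping"
    # SB3: None means "no clipping"
    if "clip_range_vf" in sb3_kwargs:
        clip_range_vf = sb3_kwargs["clip_range_vf"]
        if clip_range_vf is None:
            sb3_kwargs["clip_range_vf"] = sb3_kwargs["clip_range"]
        elif clip_range_vf < 0:
            sb3_kwargs["clip_range_vf"] = None
    return sb3_kwargs


sb2_kwargs = {
    "n_steps": 128,
    "nminibatches": 4,
    "cliprange": 0.2,
    "cliprange_vf": -1,
    "lam": 0.95,
    "noptepochs": 4,
}
sb3_kwargs = convert_ppo_kwargs(sb2_kwargs, n_envs=8)
```

With 8 environments, 128 steps and 4 minibatches, the equivalent SB3 `batch_size` is `(128 * 8) // 4 = 256`.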
#### DQN
Only the vanilla DQN is implemented right now but extensions will follow.
Default hyperparameters are taken from the Nature paper, except for the optimizer and learning rate that were taken from Stable Baselines defaults.
#### DDPG
DDPG now follows the same interface as SAC/TD3.
For state/reward normalization, you should use `VecNormalize` as for all other algorithms.
#### SAC/TD3
SAC/TD3 now accept any number of critics, e.g. `policy_kwargs=dict(n_critics=3)`, instead of only two before.
:::{note}
SAC/TD3 default hyperparameters (including network architecture) now match the ones from the original papers.
DDPG is using TD3 defaults.
:::
#### SAC
SAC implementation matches the latest version of the original implementation: it uses two Q function networks and two target Q function networks
instead of two Q function networks and one Value function network (SB2 implementation, first version of the original implementation).
Despite this change, no change in performance should be expected.
:::{note}
The SAC `predict()` method now has `deterministic=False` by default for consistency.
To match SB2 behavior, you need to explicitly pass `deterministic=True`.
:::
#### HER
The `HER` implementation now only supports online sampling of the new goals. This is done in a vectorized version.
The goal selection strategy `RANDOM` is no longer supported.
### New logger API
- Methods were renamed in the logger:
  - `logkv` -> `record`, `writekvs` -> `write`, `writeseq` -> `write_sequence`
  - `logkvs` -> `record_dict`, `dumpkvs` -> `dump`
  - `getkvs` -> `get_log_dict`, `logkv_mean` -> `record_mean`
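When migrating a larger code base, the renames listed above can be applied mechanically. The sketch below is a hypothetical helper (not part of SB3) that rewrites method calls in a source string; it matches any `.old_name(` call regardless of the object it is called on, so the result should be reviewed by hand.

```python
import re

# Old SB2 logger method -> new SB3 logger method (taken from the list above)
RENAMED_METHODS = {
    "logkv": "record",
    "writekvs": "write",
    "writeseq": "write_sequence",
    "logkvs": "record_dict",
    "dumpkvs": "dump",
    "getkvs": "get_log_dict",
    "logkv_mean": "record_mean",
}

# Match `.old_name(`; longest names first so `logkv` does not shadow `logkvs`
_PATTERN = re.compile(
    r"\.(" + "|".join(sorted(RENAMED_METHODS, key=len, reverse=True)) + r")\("
)


def rename_logger_calls(source: str) -> str:
    """Rewrite SB2 logger method calls in a source string (hypothetical helper)."""
    return _PATTERN.sub(lambda match: f".{RENAMED_METHODS[match.group(1)]}(", source)


old_code = "logger.logkv('loss', 0.5); logger.logkv_mean('r', 1.0); logger.dumpkvs()"
new_code = rename_logger_calls(old_code)
print(new_code)
```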
### Internal Changes
Please read the {ref}`Developer Guide <developer>` section.
## New Features (SB3 vs SB2)
- Much cleaner and more consistent code base (and no more warnings =D!) with static type checks
- Independent saving/loading/predict for policies
- A2C now supports Generalized Advantage Estimation (GAE) and advantage normalization (both are deactivated by default)
- Generalized State-Dependent Exploration (gSDE) is available for A2C/PPO/SAC. It allows using RL directly on real robots (cf <https://arxiv.org/abs/2005.05719>)
- Better saving/loading: optimizers are now included in the saved parameters and there are two new methods `save_replay_buffer` and `load_replay_buffer` for the replay buffer when using off-policy algorithms (DQN/DDPG/SAC/TD3)
- You can pass `optimizer_class` and `optimizer_kwargs` to `policy_kwargs` in order to easily
customize optimizers
- Seeding now works properly to have deterministic results
- The replay buffer no longer grows: everything is allocated at build time (faster)
- We added a memory-efficient replay buffer variant (pass `optimize_memory_usage=True` to the constructor), which drastically reduces memory usage, especially when using images
- You can specify an arbitrary number of critics for SAC/TD3 (e.g. `policy_kwargs=dict(n_critics=3)`)
================================================
FILE: docs/guide/plotting.md
================================================
(plotting)=
# Plotting
Stable Baselines3 provides utilities for plotting training results, allowing you to monitor and visualize your agent's learning progress.
The plotting functionality is provided by the `results_plotter` module, which can load monitor files created during training and generate various plots.
:::{note}
We recommend using the
[RL Baselines3 Zoo plotting scripts](https://rl-baselines3-zoo.readthedocs.io/en/master/guide/plot.html)
which provide plotting capabilities with confidence intervals, and publication-ready visualizations.
:::
## Recommended Approach: RL Baselines3 Zoo Plotting
The [RL Zoo](https://github.com/DLR-RM/rl-baselines3-zoo) provides scripts that let you compare results across different environments and produce publication-ready plots with confidence intervals.
The three main plotting scripts are:
- [plot_train.py](https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/rl_zoo3/plots/plot_train.py): For training plots
- [all_plots.py](https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/rl_zoo3/plots/all_plots.py): For evaluation plots, to post-process the result
- [plot_from_file.py](https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/rl_zoo3/plots/plot_from_file.py): For more advanced plotting from post-processed results
These scripts offer features that are not included in the basic SB3 plotting utilities.
### Installation
First, install RL Baselines3 Zoo:
```bash
pip install 'rl_zoo3[plots]'
```
### Basic Training Plot Examples
```bash
# Train an agent
python -m rl_zoo3.train --algo ppo --env CartPole-v1 -f logs/
# Plot training results for a single algorithm
python -m rl_zoo3.plots.plot_train --algo ppo --env CartPole-v1 --exp-folder logs/
```
### Evaluation and Comparison Plots
```bash
# Generate evaluation plots and save post-processed results
# in `logs/demo_plots.pkl` in order to use `plot_from_file`
python -m rl_zoo3.plots.all_plots --algo ppo sac -e Pendulum-v1 -f logs/ -o logs/demo_plots
# More advanced plotting from post-processed results (with confidence intervals)
python -m rl_zoo3.plots.plot_from_file -i logs/demo_plots.pkl --rliable --ci-size 0.95
```
For more examples, please read the
[RL Baselines3 Zoo plotting guide](https://rl-baselines3-zoo.readthedocs.io/en/master/guide/plot.html).
## Real-Time Monitoring
For real-time monitoring during training, consider using the plotting functions within callbacks
(see the [Callbacks guide](callbacks.md)) or integrating with tools like [Tensorboard](tensorboard.md) or Weights & Biases
(see the [Integrations guide](integrations.md)).
## Monitor File Format
The `Monitor` wrapper saves training data in CSV format with the following columns:
- `r`: Episode return (sum of rewards for one episode)
- `l`: Episode length (number of steps)
- `t`: Timestamp (wall-clock time when episode ended)
Additional columns may be present if you log custom metrics in the environment's info dict and pass their names via the `info_keywords` parameter.
:::{note}
The plotting functions automatically handle multiple monitor files from the same directory.
This occurs when using vectorized environments. Episodes are loaded and sorted by timestamp
to ensure they are in the correct chronological order.
:::
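To inspect a monitor file without the plotting helpers, it can be parsed with the standard library. The sketch below runs on made-up file contents and assumes that the first line of the file is a JSON header prefixed with `#`, followed by a regular CSV table; extra `info_keywords` columns are ignored here.

```python
import csv
import io
import json

# Made-up contents of a `monitor.csv` file. The first line is assumed to be
# a JSON header prefixed with `#`, followed by a regular CSV table.
sample = """#{"t_start": 1700000000.0, "env_id": "CartPole-v1"}
r,l,t
21.0,21,0.52
35.0,35,1.10
12.0,12,1.31
"""


def read_monitor(text: str):
    """Return (header metadata, list of episode dicts) from monitor file contents."""
    stream = io.StringIO(text)
    first_line = stream.readline()
    metadata = json.loads(first_line[1:])  # drop the leading '#'
    episodes = [
        {"r": float(row["r"]), "l": int(row["l"]), "t": float(row["t"])}
        for row in csv.DictReader(stream)
    ]
    return metadata, episodes


metadata, episodes = read_monitor(sample)
# Summing episode lengths gives the number of timesteps covered by the file
total_timesteps = sum(episode["l"] for episode in episodes)
mean_return = sum(episode["r"] for episode in episodes) / len(episodes)
print(metadata["env_id"], total_timesteps, mean_return)
```

For real files, `load_results` from `stable_baselines3.common.monitor` remains the more convenient option, since it also merges several monitor files from the same directory.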
## Basic SB3 Plotting (Simple Use Cases)
### Basic Plotting: Single Training Run
The simplest way to plot training results is to use the `plot_results` function after training an agent.
This function reads the monitor files created by the `Monitor` wrapper and plots the episode rewards over time.
```python
import os
import gymnasium as gym
import matplotlib.pyplot as plt
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.results_plotter import plot_results
from stable_baselines3.common import results_plotter
# Create log directory
log_dir = "tmp/"
os.makedirs(log_dir, exist_ok=True)
# Create and wrap the environment with Monitor
env = gym.make("CartPole-v1")
env = Monitor(env, log_dir)
# Train the agent
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=20_000)
# Plot the results
plot_results([log_dir], 20_000, results_plotter.X_TIMESTEPS, "PPO CartPole")
plt.show()
```
### Different Plotting Modes
The plotting functions support three different x-axis modes:
- `X_TIMESTEPS`: Plot rewards vs. timesteps (default)
- `X_EPISODES`: Plot rewards vs. episode number
- `X_WALLTIME`: Plot rewards vs. wall-clock time in hours
```python
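import matplotlib.pyplot as plt
from stable_baselines3.common import results_plotter
from stable_baselines3.common.results_plotter import plot_results
# `log_dir` is the monitor directory created in the previous example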
# Plot by timesteps (shows sample efficiency)
# plot_results([log_dir], None, results_plotter.X_TIMESTEPS, "Rewards vs Timesteps")
# By Episodes
plot_results([log_dir], None, results_plotter.X_EPISODES, "Rewards vs Episodes")
# plot_results([log_dir], None, results_plotter.X_WALLTIME, "Rewards vs Time")
plt.tight_layout()
plt.show()
```
### Advanced Plotting with Manual Data Processing
For more control over the plotting, you can use the underlying functions to process the data manually:
```python
import numpy as np
import matplotlib.pyplot as plt
from stable_baselines3.common.monitor import load_results
from stable_baselines3.common.results_plotter import ts2xy, window_func
# Load the results
df = load_results(log_dir)
# Convert dataframe (x=timesteps, y=episodic return)
x, y = ts2xy(df, "timesteps")
# Plot raw data
plt.figure(figsize=(10, 6))
plt.subplot(2, 1, 1)
plt.scatter(x, y, s=2, alpha=0.6)
plt.xlabel("Timesteps")
plt.ylabel("Episode Reward")
plt.title("Raw Episode Rewards")
# Plot smoothed data with custom window
plt.subplot(2, 1, 2)
if len(x) >= 50:  # Only smooth if we have enough data
    x_smooth, y_smooth = window_func(x, y, 50, np.mean)
    plt.plot(x_smooth, y_smooth, linewidth=2)
plt.xlabel("Timesteps")
plt.ylabel("Average Episode Reward (50-episode window)")
plt.title("Smoothed Episode Rewards")
plt.tight_layout()
plt.show()
```
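To make the smoothing step more concrete, here is a plain Python sketch of a rolling mean. It reflects one reading of what `window_func(x, y, 50, np.mean)` does, namely averaging `y` over a sliding window and aligning each average with the last point of its window; `rolling_mean` is a hypothetical stand-in, not the SB3 implementation, and the data is made up.

```python
def rolling_mean(x, y, window):
    """Average `y` over a sliding window and trim `x` to match."""
    if len(y) < window:
        raise ValueError("window is larger than the number of episodes")
    smoothed = [
        sum(y[i : i + window]) / window for i in range(len(y) - window + 1)
    ]
    # Each average is aligned with the last timestep of its window
    return x[window - 1 :], smoothed


timesteps = [10, 25, 45, 60, 90]
returns = [10.0, 15.0, 20.0, 15.0, 30.0]
x_smooth, y_smooth = rolling_mean(timesteps, returns, window=3)
print(x_smooth, y_smooth)
```

With a window of `w` episodes, the smoothed curve has `w - 1` fewer points than the raw data, which is why the example above only smooths when at least 50 episodes are available.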
================================================
FILE: docs/guide/quickstart.md
================================================
(quickstart)=
# Getting Started
:::{note}
Stable-Baselines3 (SB3) uses [vectorized environments (VecEnv)](vec_envs.md) internally.
Please read the associated section to learn more about its features and differences compared to a single Gym environment.
:::
Most of the library tries to follow a sklearn-like syntax for the Reinforcement Learning algorithms.
Here is a quick example of how to train and run A2C on a CartPole environment:
```python
import gymnasium as gym
from stable_baselines3 import A2C
env = gym.make("CartPole-v1", render_mode="rgb_array")
model = A2C("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000)
vec_env = model.get_env()
obs = vec_env.reset()
for i in range(1000):
    action, _state = model.predict(obs, deterministic=True)
    obs, reward, done, info = vec_env.step(action)
    vec_env.render("human")
    # VecEnv resets automatically
    # if done:
    #   obs = vec_env.reset()
```
:::{note}
You can find explanations about the logger output and names in the {ref}`Logger <logger>` section.
:::
Or just train a model with a one-liner if
[the environment is registered in Gymnasium](https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/#registering-envs) and if
the policy is registered:
```python
from stable_baselines3 import A2C
model = A2C("MlpPolicy", "CartPole-v1").learn(10_000)
```
================================================
FILE: docs/guide/rl.md
================================================
(rl)=
# Reinforcement Learning Resources
Stable-Baselines3 assumes that you already understand the basic concepts of Reinforcement Learning (RL).
However, if you want to learn about RL, there are several good resources to get started:
- [OpenAI Spinning Up](https://spinningup.openai.com/en/latest/)
- [The Deep Reinforcement Learning Course](https://huggingface.co/learn/deep-rl-course/unit0/introduction)
- [David Silver's course](http://www0.cs.ucl.ac.uk/staff/d.silver/web/Teaching.html)
- [Lilian Weng's blog](https://lilianweng.github.io/lil-log/2018/04/08/policy-gradient-algorithms.html)
- [Berkeley's Deep RL Bootcamp](https://sites.google.com/view/deep-rl-bootcamp/lectures)
- [Berkeley's Deep Reinforcement Learning course](http://rail.eecs.berkeley.edu/deeprlcourse/)
- [DQN tutorial](https://github.com/araffin/rlss23-dqn-tutorial)
- [Decisions & Dragons - FAQ for RL foundations](https://www.decisionsanddragons.com)
- [More resources](https://github.com/dennybritz/reinforcement-learning)
================================================
FILE: docs/guide/rl_tips.md
================================================
(rl-tips)=
# Reinforcement Learning Tips and Tricks
The aim of this section is to help you run reinforcement learning experiments.
It covers general advice about RL (where to start, which algorithm to choose, how to evaluate an algorithm, ...),
as well as tips and tricks when using a custom environment or implementing an RL algorithm.
:::{note}
We have a [video on YouTube](https://www.youtube.com/watch?v=Ikngt0_DXJg) that covers
this section in more detail. You can also find the [slides here](https://araffin.github.io/slides/rlvs-tips-tricks/).
:::
:::{note}
We also have a [video on Designing and Running Real-World RL Experiments](https://youtu.be/eZ6ZEpCi6D8); the slides [can be found online](https://araffin.github.io/slides/design-real-rl-experiments/).
:::
## General advice when using Reinforcement Le
SYMBOL INDEX (1248 symbols across 82 files)
FILE: docs/conf.py
function setup (line 120) | def setup(app):
FILE: stable_baselines3/__init__.py
function HER (line 18) | def HER(*args, **kwargs):
FILE: stable_baselines3/a2c/a2c.py
class A2C (line 16) | class A2C(OnPolicyAlgorithm):
method __init__ (line 66) | def __init__(
method train (line 132) | def train(self) -> None:
method learn (line 192) | def learn(
FILE: stable_baselines3/common/atari_wrappers.py
class StickyActionEnv (line 17) | class StickyActionEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
method __init__ (line 28) | def __init__(self, env: gym.Env, action_repeat_probability: float) -> ...
method reset (line 33) | def reset(self, **kwargs) -> AtariResetReturn:
method step (line 37) | def step(self, action: int) -> AtariStepReturn:
class NoopResetEnv (line 43) | class NoopResetEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
method __init__ (line 52) | def __init__(self, env: gym.Env, noop_max: int = 30) -> None:
method reset (line 59) | def reset(self, **kwargs) -> AtariResetReturn:
class FireResetEnv (line 75) | class FireResetEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
method __init__ (line 82) | def __init__(self, env: gym.Env) -> None:
method reset (line 87) | def reset(self, **kwargs) -> AtariResetReturn:
class EpisodicLifeEnv (line 98) | class EpisodicLifeEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
method __init__ (line 114) | def __init__(self, env: gym.Env) -> None:
method step (line 119) | def step(self, action: int) -> AtariStepReturn:
method reset (line 133) | def reset(self, **kwargs) -> AtariResetReturn:
class MaxAndSkipEnv (line 157) | class MaxAndSkipEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
method __init__ (line 167) | def __init__(self, env: gym.Env, skip: int = 4) -> None:
method step (line 175) | def step(self, action: int) -> AtariStepReturn:
class ClipRewardEnv (line 202) | class ClipRewardEnv(gym.RewardWrapper):
method __init__ (line 209) | def __init__(self, env: gym.Env) -> None:
method reward (line 212) | def reward(self, reward: SupportsFloat) -> float:
class WarpFrame (line 222) | class WarpFrame(gym.ObservationWrapper[np.ndarray, int, np.ndarray]):
method __init__ (line 232) | def __init__(self, env: gym.Env, width: int = 84, height: int = 84) ->...
method observation (line 245) | def observation(self, frame: np.ndarray) -> np.ndarray:
class AtariWrapper (line 258) | class AtariWrapper(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
method __init__ (line 289) | def __init__(
FILE: stable_baselines3/common/base_class.py
function maybe_make_env (line 48) | def maybe_make_env(env: GymEnv | str, verbose: int) -> GymEnv:
class BaseAlgorithm (line 67) | class BaseAlgorithm(ABC):
method __init__ (line 106) | def __init__(
method _wrap_env (line 204) | def _wrap_env(env: GymEnv, verbose: int = 0, monitor_wrapper: bool = T...
method _setup_model (line 252) | def _setup_model(self) -> None:
method set_logger (line 255) | def set_logger(self, logger: Logger) -> None:
method logger (line 270) | def logger(self) -> Logger:
method _setup_lr_schedule (line 274) | def _setup_lr_schedule(self) -> None:
method _update_current_progress_remaining (line 278) | def _update_current_progress_remaining(self, num_timesteps: int, total...
method _update_learning_rate (line 287) | def _update_learning_rate(self, optimizers: list[th.optim.Optimizer] |...
method _excluded_save_params (line 303) | def _excluded_save_params(self) -> list[str]:
method _get_policy_from_name (line 323) | def _get_policy_from_name(self, policy_name: str) -> type[BasePolicy]:
method _get_torch_save_params (line 340) | def _get_torch_save_params(self) -> tuple[list[str], list[str]]:
method _init_callback (line 358) | def _init_callback(
method _setup_learn (line 383) | def _setup_learn(
method _update_info_buffer (line 438) | def _update_info_buffer(self, infos: list[dict[str, Any]], dones: np.n...
method get_env (line 459) | def get_env(self) -> VecEnv | None:
method get_vec_normalize_env (line 467) | def get_vec_normalize_env(self) -> VecNormalize | None:
method set_env (line 476) | def set_env(self, env: GymEnv, force_reset: bool = True) -> None:
method learn (line 512) | def learn(
method predict (line 537) | def predict(
method set_random_seed (line 559) | def set_random_seed(self, seed: int | None = None) -> None:
method set_parameters (line 574) | def set_parameters(
method load (line 643) | def load( # noqa: C901
method get_parameters (line 804) | def get_parameters(self) -> dict[str, dict]:
method save (line 819) | def save(
method dump_logs (line 869) | def dump_logs(self) -> None:
method _dump_logs (line 875) | def _dump_logs(self, *args) -> None:
FILE: stable_baselines3/common/buffers.py
class BaseBuffer (line 27) | class BaseBuffer(ABC):
method __init__ (line 42) | def __init__(
method swap_and_flatten (line 63) | def swap_and_flatten(arr: np.ndarray) -> np.ndarray:
method size (line 77) | def size(self) -> int:
method add (line 85) | def add(self, *args, **kwargs) -> None:
method extend (line 91) | def extend(self, *args, **kwargs) -> None:
method reset (line 99) | def reset(self) -> None:
method sample (line 106) | def sample(self, batch_size: int, env: VecNormalize | None = None):
method _get_samples (line 118) | def _get_samples(
method to_torch (line 128) | def to_torch(self, array: np.ndarray, copy: bool = True) -> th.Tensor:
method _normalize_obs (line 143) | def _normalize_obs(
method _normalize_reward (line 152) | def _normalize_reward(reward: np.ndarray, env: VecNormalize | None = N...
class ReplayBuffer (line 158) | class ReplayBuffer(BaseBuffer):
method __init__ (line 185) | def __init__(
method add (line 247) | def add(
method sample (line 285) | def sample(self, batch_size: int, env: VecNormalize | None = None) -> ...
method _get_samples (line 307) | def _get_samples(self, batch_inds: np.ndarray, env: VecNormalize | Non...
method _maybe_cast_dtype (line 328) | def _maybe_cast_dtype(dtype: np.typing.DTypeLike | None) -> np.typing....
class RolloutBuffer (line 343) | class RolloutBuffer(BaseBuffer):
method __init__ (line 375) | def __init__(
method reset (line 391) | def reset(self) -> None:
method compute_returns_and_advantage (line 403) | def compute_returns_and_advantage(self, last_values: th.Tensor, dones:...
method add (line 440) | def add(
method get (line 481) | def get(self, batch_size: int | None = None) -> Generator[RolloutBuffe...
method _get_samples (line 508) | def _get_samples(
class DictReplayBuffer (line 525) | class DictReplayBuffer(ReplayBuffer):
method __init__ (line 547) | def __init__(
method add (line 612) | def add( # type: ignore[override]
method sample (line 649) | def sample( # type: ignore[override]
method _get_samples (line 664) | def _get_samples( # type: ignore[override]
class DictRolloutBuffer (line 697) | class DictRolloutBuffer(RolloutBuffer):
method __init__ (line 726) | def __init__(
method reset (line 746) | def reset(self) -> None:
method add (line 762) | def add( # type: ignore[override]
method get (line 805) | def get( # type: ignore[override]
method _get_samples (line 831) | def _get_samples( # type: ignore[override]
class NStepReplayBuffer (line 847) | class NStepReplayBuffer(ReplayBuffer):
method __init__ (line 879) | def __init__(self, *args, n_steps: int = 3, gamma: float = 0.99, **kwa...
method _get_samples (line 886) | def _get_samples(self, batch_inds: np.ndarray, env: VecNormalize | Non...
FILE: stable_baselines3/common/callbacks.py
class BaseCallback (line 31) | class BaseCallback(ABC):
method __init__ (line 42) | def __init__(self, verbose: int = 0):
method training_env (line 56) | def training_env(self) -> VecEnv:
method logger (line 64) | def logger(self) -> Logger:
method init_callback (line 68) | def init_callback(self, model: "base_class.BaseAlgorithm") -> None:
method _init_callback (line 76) | def _init_callback(self) -> None:
method on_training_start (line 79) | def on_training_start(self, locals_: dict[str, Any], globals_: dict[st...
method _on_training_start (line 87) | def _on_training_start(self) -> None:
method on_rollout_start (line 90) | def on_rollout_start(self) -> None:
method _on_rollout_start (line 93) | def _on_rollout_start(self) -> None:
method _on_step (line 97) | def _on_step(self) -> bool:
method on_step (line 103) | def on_step(self) -> bool:
method on_training_end (line 117) | def on_training_end(self) -> None:
method _on_training_end (line 120) | def _on_training_end(self) -> None:
method on_rollout_end (line 123) | def on_rollout_end(self) -> None:
method _on_rollout_end (line 126) | def _on_rollout_end(self) -> None:
method update_locals (line 129) | def update_locals(self, locals_: dict[str, Any]) -> None:
method update_child_locals (line 138) | def update_child_locals(self, locals_: dict[str, Any]) -> None:
class EventCallback (line 147) | class EventCallback(BaseCallback):
method __init__ (line 156) | def __init__(self, callback: BaseCallback | None = None, verbose: int ...
method init_callback (line 164) | def init_callback(self, model: "base_class.BaseAlgorithm") -> None:
method _on_training_start (line 169) | def _on_training_start(self) -> None:
method _on_event (line 173) | def _on_event(self) -> bool:
method _on_step (line 178) | def _on_step(self) -> bool:
method update_child_locals (line 181) | def update_child_locals(self, locals_: dict[str, Any]) -> None:
class CallbackList (line 191) | class CallbackList(BaseCallback):
method __init__ (line 199) | def __init__(self, callbacks: list[BaseCallback]):
method _init_callback (line 204) | def _init_callback(self) -> None:
method _on_training_start (line 212) | def _on_training_start(self) -> None:
method _on_rollout_start (line 216) | def _on_rollout_start(self) -> None:
method _on_step (line 220) | def _on_step(self) -> bool:
method _on_rollout_end (line 227) | def _on_rollout_end(self) -> None:
method _on_training_end (line 231) | def _on_training_end(self) -> None:
method update_child_locals (line 235) | def update_child_locals(self, locals_: dict[str, Any]) -> None:
class CheckpointCallback (line 245) | class CheckpointCallback(BaseCallback):
method __init__ (line 268) | def __init__(
method _init_callback (line 284) | def _init_callback(self) -> None:
method _checkpoint_path (line 289) | def _checkpoint_path(self, checkpoint_type: str = "", extension: str =...
method _on_step (line 300) | def _on_step(self) -> bool:
class ConvertCallback (line 324) | class ConvertCallback(BaseCallback):
method __init__ (line 332) | def __init__(self, callback: Callable[[dict[str, Any], dict[str, Any]]...
method _on_step (line 336) | def _on_step(self) -> bool:
class EvalCallback (line 342) | class EvalCallback(EventCallback):
method __init__ (line 370) | def __init__(
method _init_callback (line 416) | def _init_callback(self) -> None:
method _log_success_callback (line 431) | def _log_success_callback(self, locals_: dict[str, Any], globals_: dic...
method _on_step (line 447) | def _on_step(self) -> bool:
method update_child_locals (line 534) | def update_child_locals(self, locals_: dict[str, Any]) -> None:
class StopTrainingOnRewardThreshold (line 544) | class StopTrainingOnRewardThreshold(BaseCallback):
method __init__ (line 559) | def __init__(self, reward_threshold: float, verbose: int = 0):
method _on_step (line 563) | def _on_step(self) -> bool:
class EveryNTimesteps (line 574) | class EveryNTimesteps(EventCallback):
method __init__ (line 583) | def __init__(self, n_steps: int, callback: BaseCallback):
method _on_step (line 588) | def _on_step(self) -> bool:
class LogEveryNTimesteps (line 595) | class LogEveryNTimesteps(EveryNTimesteps):
method __init__ (line 602) | def __init__(self, n_steps: int):
method _log_data (line 605) | def _log_data(self, _locals: dict[str, Any], _globals: dict[str, Any])...
class StopTrainingOnMaxEpisodes (line 610) | class StopTrainingOnMaxEpisodes(BaseCallback):
method __init__ (line 622) | def __init__(self, max_episodes: int, verbose: int = 0):
method _init_callback (line 628) | def _init_callback(self) -> None:
method _on_step (line 632) | def _on_step(self) -> bool:
class StopTrainingOnNoModelImprovement (line 654) | class StopTrainingOnNoModelImprovement(BaseCallback):
method __init__ (line 669) | def __init__(self, max_no_improvement_evals: int, min_evals: int = 0, ...
method _on_step (line 676) | def _on_step(self) -> bool:
class ProgressBarCallback (line 699) | class ProgressBarCallback(BaseCallback):
method __init__ (line 707) | def __init__(self) -> None:
method _on_training_start (line 716) | def _on_training_start(self) -> None:
method _on_step (line 721) | def _on_step(self) -> bool:
method _on_training_end (line 726) | def _on_training_end(self) -> None:
FILE: stable_baselines3/common/distributions.py
class Distribution (line 26) | class Distribution(ABC):
method __init__ (line 31) | def __init__(self):
method proba_distribution_net (line 35) | def proba_distribution_net(self, *args, **kwargs) -> nn.Module | tuple...
method proba_distribution (line 42) | def proba_distribution(self: SelfDistribution, *args, **kwargs) -> Sel...
method log_prob (line 49) | def log_prob(self, actions: th.Tensor) -> th.Tensor:
method entropy (line 58) | def entropy(self) -> th.Tensor | None:
method sample (line 66) | def sample(self) -> th.Tensor:
method mode (line 74) | def mode(self) -> th.Tensor:
method get_actions (line 82) | def get_actions(self, deterministic: bool = False) -> th.Tensor:
method actions_from_params (line 94) | def actions_from_params(self, *args, **kwargs) -> th.Tensor:
method log_prob_from_params (line 103) | def log_prob_from_params(self, *args, **kwargs) -> tuple[th.Tensor, th...
function sum_independent_dims (line 112) | def sum_independent_dims(tensor: th.Tensor) -> th.Tensor:
class DiagGaussianDistribution (line 127) | class DiagGaussianDistribution(Distribution):
method __init__ (line 136) | def __init__(self, action_dim: int):
method proba_distribution_net (line 140) | def proba_distribution_net(self, latent_dim: int, log_std_init: float ...
method proba_distribution (line 155) | def proba_distribution(
method log_prob (line 169) | def log_prob(self, actions: th.Tensor) -> th.Tensor:
method entropy (line 180) | def entropy(self) -> th.Tensor | None:
method sample (line 183) | def sample(self) -> th.Tensor:
method mode (line 187) | def mode(self) -> th.Tensor:
method actions_from_params (line 190) | def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Ten...
method log_prob_from_params (line 195) | def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Te...
class SquashedDiagGaussianDistribution (line 209) | class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
method __init__ (line 217) | def __init__(self, action_dim: int, epsilon: float = 1e-6):
method proba_distribution (line 223) | def proba_distribution(
method log_prob (line 229) | def log_prob(self, actions: th.Tensor, gaussian_actions: th.Tensor | N...
method entropy (line 244) | def entropy(self) -> th.Tensor | None:
method sample (line 249) | def sample(self) -> th.Tensor:
method mode (line 254) | def mode(self) -> th.Tensor:
method log_prob_from_params (line 259) | def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Te...
class CategoricalDistribution (line 265) | class CategoricalDistribution(Distribution):
method __init__ (line 274) | def __init__(self, action_dim: int):
method proba_distribution_net (line 278) | def proba_distribution_net(self, latent_dim: int) -> nn.Module:
method proba_distribution (line 291) | def proba_distribution(self: SelfCategoricalDistribution, action_logit...
method log_prob (line 295) | def log_prob(self, actions: th.Tensor) -> th.Tensor:
method entropy (line 298) | def entropy(self) -> th.Tensor:
method sample (line 301) | def sample(self) -> th.Tensor:
method mode (line 304) | def mode(self) -> th.Tensor:
method actions_from_params (line 307) | def actions_from_params(self, action_logits: th.Tensor, deterministic:...
method log_prob_from_params (line 312) | def log_prob_from_params(self, action_logits: th.Tensor) -> tuple[th.T...
class MultiCategoricalDistribution (line 318) | class MultiCategoricalDistribution(Distribution):
method __init__ (line 327) | def __init__(self, action_dims: list[int]):
method proba_distribution_net (line 331) | def proba_distribution_net(self, latent_dim: int) -> nn.Module:
method proba_distribution (line 345) | def proba_distribution(
method log_prob (line 351) | def log_prob(self, actions: th.Tensor) -> th.Tensor:
method entropy (line 357) | def entropy(self) -> th.Tensor:
method sample (line 360) | def sample(self) -> th.Tensor:
method mode (line 363) | def mode(self) -> th.Tensor:
method actions_from_params (line 366) | def actions_from_params(self, action_logits: th.Tensor, deterministic:...
method log_prob_from_params (line 371) | def log_prob_from_params(self, action_logits: th.Tensor) -> tuple[th.T...
class BernoulliDistribution (line 377) | class BernoulliDistribution(Distribution):
method __init__ (line 386) | def __init__(self, action_dims: int):
method proba_distribution_net (line 390) | def proba_distribution_net(self, latent_dim: int) -> nn.Module:
method proba_distribution (line 402) | def proba_distribution(self: SelfBernoulliDistribution, action_logits:...
method log_prob (line 406) | def log_prob(self, actions: th.Tensor) -> th.Tensor:
method entropy (line 409) | def entropy(self) -> th.Tensor:
method sample (line 412) | def sample(self) -> th.Tensor:
method mode (line 415) | def mode(self) -> th.Tensor:
method actions_from_params (line 418) | def actions_from_params(self, action_logits: th.Tensor, deterministic:...
method log_prob_from_params (line 423) | def log_prob_from_params(self, action_logits: th.Tensor) -> tuple[th.T...
class StateDependentNoiseDistribution (line 429) | class StateDependentNoiseDistribution(Distribution):
method __init__ (line 459) | def __init__(
method get_std (line 480) | def get_std(self, log_std: th.Tensor) -> th.Tensor:
method sample_weights (line 506) | def sample_weights(self, log_std: th.Tensor, batch_size: int = 1) -> N...
method proba_distribution_net (line 521) | def proba_distribution_net(
method proba_distribution (line 548) | def proba_distribution(
method log_prob (line 565) | def log_prob(self, actions: th.Tensor) -> th.Tensor:
method entropy (line 580) | def entropy(self) -> th.Tensor | None:
method sample (line 587) | def sample(self) -> th.Tensor:
method mode (line 594) | def mode(self) -> th.Tensor:
method get_noise (line 600) | def get_noise(self, latent_sde: th.Tensor) -> th.Tensor:
method actions_from_params (line 612) | def actions_from_params(
method log_prob_from_params (line 619) | def log_prob_from_params(
class TanhBijector (line 627) | class TanhBijector:
method __init__ (line 635) | def __init__(self, epsilon: float = 1e-6):
method forward (line 640) | def forward(x: th.Tensor) -> th.Tensor:
method atanh (line 644) | def atanh(x: th.Tensor) -> th.Tensor:
method inverse (line 654) | def inverse(y: th.Tensor) -> th.Tensor:
method log_prob_correction (line 665) | def log_prob_correction(self, x: th.Tensor) -> th.Tensor:
function make_proba_distribution (line 670) | def make_proba_distribution(
function kl_divergence (line 705) | def kl_divergence(dist_true: Distribution, dist_pred: Distribution) -> t...
FILE: stable_baselines3/common/env_checker.py
function _is_oneof_space (line 12) | def _is_oneof_space(space: spaces.Space) -> bool:
function _is_numpy_array_space (line 24) | def _is_numpy_array_space(space: spaces.Space) -> bool:
function _starts_at_zero (line 32) | def _starts_at_zero(space: spaces.Discrete | spaces.MultiDiscrete) -> bool:
function _check_non_zero_start (line 39) | def _check_non_zero_start(space: spaces.Space, space_type: str = "observ...
function _check_image_input (line 57) | def _check_image_input(observation_space: spaces.Box, key: str = "") -> ...
function _check_unsupported_spaces (line 95) | def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Sp...
function _check_nan (line 190) | def _check_nan(env: gym.Env) -> None:
function _is_goal_env (line 199) | def _is_goal_env(env: gym.Env) -> bool:
function _check_goal_env_obs (line 207) | def _check_goal_env_obs(obs: dict, observation_space: spaces.Dict, metho...
function _check_goal_env_compute_reward (line 227) | def _check_goal_env_compute_reward(
function _check_obs (line 254) | def _check_obs(obs: tuple | dict | np.ndarray | int, observation_space: ...
function _check_box_obs (line 311) | def _check_box_obs(observation_space: spaces.Box, key: str = "") -> None:
function _check_returned_values (line 331) | def _check_returned_values(env: gym.Env, observation_space: spaces.Space...
function _check_spaces (line 413) | def _check_spaces(env: gym.Env) -> None:
function _check_render (line 442) | def _check_render(env: gym.Env, warn: bool = False) -> None: # pragma: ...
function check_env (line 467) | def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool =...
FILE: stable_baselines3/common/env_util.py
function unwrap_wrapper (line 13) | def unwrap_wrapper(env: gym.Env, wrapper_class: type[gym.Wrapper]) -> gy...
function is_wrapped (line 29) | def is_wrapped(env: gym.Env, wrapper_class: type[gym.Wrapper]) -> bool:
function make_vec_env (line 40) | def make_vec_env(
function make_atari_env (line 132) | def make_atari_env(
FILE: stable_baselines3/common/envs/bit_flipping_env.py
class BitFlippingEnv (line 11) | class BitFlippingEnv(Env):
method __init__ (line 33) | def __init__(
method seed (line 68) | def seed(self, seed: int) -> None:
method convert_if_needed (line 71) | def convert_if_needed(self, state: np.ndarray) -> int | np.ndarray:
method convert_to_bit_vector (line 92) | def convert_to_bit_vector(self, state: int | np.ndarray, batch_size: i...
method _make_observation_space (line 111) | def _make_observation_space(self, discrete_obs_space: bool, image_obs_...
method _get_obs (line 169) | def _get_obs(self) -> dict[str, int | np.ndarray]:
method reset (line 183) | def reset(self, *, seed: int | None = None, options: dict | None = Non...
method step (line 190) | def step(self, action: np.ndarray | int) -> GymStepReturn:
method compute_reward (line 210) | def compute_reward(
method render (line 229) | def render(self) -> np.ndarray | None: # type: ignore[override]
method close (line 235) | def close(self) -> None:
FILE: stable_baselines3/common/envs/identity_env.py
class IdentityEnv (line 12) | class IdentityEnv(gym.Env, Generic[T]):
method __init__ (line 13) | def __init__(self, dim: int | None = None, space: spaces.Space | None ...
method reset (line 37) | def reset(self, *, seed: int | None = None, options: dict | None = Non...
method step (line 45) | def step(self, action: T) -> tuple[T, float, bool, bool, dict[str, Any]]:
method _choose_next_state (line 53) | def _choose_next_state(self) -> None:
method _get_reward (line 56) | def _get_reward(self, action: T) -> float:
method render (line 59) | def render(self, mode: str = "human") -> None:
class IdentityEnvBox (line 63) | class IdentityEnvBox(IdentityEnv[np.ndarray]):
method __init__ (line 64) | def __init__(self, low: float = -1.0, high: float = 1.0, eps: float = ...
method step (line 77) | def step(self, action: np.ndarray) -> tuple[np.ndarray, float, bool, b...
method _get_reward (line 85) | def _get_reward(self, action: np.ndarray) -> float:
class IdentityEnvMultiDiscrete (line 89) | class IdentityEnvMultiDiscrete(IdentityEnv[np.ndarray]):
method __init__ (line 90) | def __init__(self, dim: int = 1, ep_length: int = 100) -> None:
class IdentityEnvMultiBinary (line 101) | class IdentityEnvMultiBinary(IdentityEnv[np.ndarray]):
method __init__ (line 102) | def __init__(self, dim: int = 1, ep_length: int = 100) -> None:
class FakeImageEnv (line 113) | class FakeImageEnv(gym.Env):
method __init__ (line 125) | def __init__(
method reset (line 145) | def reset(self, *, seed: int | None = None, options: dict | None = Non...
method step (line 151) | def step(self, action: np.ndarray | int) -> GymStepReturn:
method render (line 158) | def render(self, mode: str = "human") -> None:
FILE: stable_baselines3/common/envs/multi_input_envs.py
class SimpleMultiObsEnv (line 8) | class SimpleMultiObsEnv(gym.Env):
method __init__ (line 36) | def __init__(
method init_state_mapping (line 79) | def init_state_mapping(self, num_col: int, num_row: int) -> None:
method get_state_mapping (line 95) | def get_state_mapping(self) -> dict[str, np.ndarray]:
method init_possible_transitions (line 103) | def init_possible_transitions(self) -> None:
method step (line 122) | def step(self, action: int | np.ndarray) -> GymStepReturn:
method render (line 159) | def render(self, mode: str = "human") -> None:
method reset (line 167) | def reset(self, *, seed: int | None = None, options: dict | None = Non...
FILE: stable_baselines3/common/evaluation.py
function evaluate_policy (line 12) | def evaluate_policy(
FILE: stable_baselines3/common/logger.py
class Video (line 35) | class Video:
method __init__ (line 43) | def __init__(self, frames: th.Tensor, fps: float):
class Figure (line 48) | class Figure:
method __init__ (line 56) | def __init__(self, figure: matplotlib.figure.Figure, close: bool):
class Image (line 61) | class Image:
method __init__ (line 71) | def __init__(self, image: th.Tensor | np.ndarray | str, dataformats: s...
class HParam (line 76) | class HParam:
method __init__ (line 85) | def __init__(self, hparam_dict: Mapping[str, bool | str | float | None...
class FormatUnsupportedError (line 92) | class FormatUnsupportedError(NotImplementedError):
method __init__ (line 102) | def __init__(self, unsupported_formats: Sequence[str], value_descripti...
class KVWriter (line 113) | class KVWriter:
method write (line 118) | def write(self, key_values: dict[str, Any], key_excluded: dict[str, tu...
method close (line 128) | def close(self) -> None:
class SeqWriter (line 135) | class SeqWriter:
method write_sequence (line 140) | def write_sequence(self, sequence: list[str]) -> None:
class HumanOutputFormat (line 149) | class HumanOutputFormat(KVWriter, SeqWriter):
method __init__ (line 163) | def __init__(self, filename_or_file: str | TextIO, max_length: int = 36):
method write (line 176) | def write(self, key_values: dict[str, Any], key_excluded: dict[str, tu...
method _truncate (line 243) | def _truncate(self, string: str) -> str:
method write_sequence (line 248) | def write_sequence(self, sequence: list[str]) -> None:
method close (line 256) | def close(self) -> None:
function filter_excluded_keys (line 264) | def filter_excluded_keys(key_values: dict[str, Any], key_excluded: dict[...
class JSONOutputFormat (line 280) | class JSONOutputFormat(KVWriter):
method __init__ (line 287) | def __init__(self, filename: str):
method write (line 290) | def write(self, key_values: dict[str, Any], key_excluded: dict[str, tu...
method close (line 316) | def close(self) -> None:
class CSVOutputFormat (line 324) | class CSVOutputFormat(KVWriter):
method __init__ (line 331) | def __init__(self, filename: str):
method write (line 337) | def write(self, key_values: dict[str, Any], key_excluded: dict[str, tu...
method close (line 384) | def close(self) -> None:
class TensorBoardOutputFormat (line 391) | class TensorBoardOutputFormat(KVWriter):
method __init__ (line 398) | def __init__(self, folder: str):
method write (line 403) | def write(self, key_values: dict[str, Any], key_excluded: dict[str, tu...
method close (line 439) | def close(self) -> None:
function make_output_format (line 448) | def make_output_format(_format: str, log_dir: str, log_suffix: str = "")...
class Logger (line 477) | class Logger:
method __init__ (line 485) | def __init__(self, folder: str | None, output_formats: list[KVWriter]):
method to_tuple (line 494) | def to_tuple(string_or_tuple: str | tuple[str, ...] | None) -> tuple[s...
method record (line 504) | def record(self, key: str, value: Any, exclude: str | tuple[str, ...] ...
method record_mean (line 517) | def record_mean(self, key: str, value: float | None, exclude: str | tu...
method dump (line 532) | def dump(self, step: int = 0) -> None:
method log (line 546) | def log(self, *args, level: int = INFO) -> None:
method debug (line 560) | def debug(self, *args) -> None:
method info (line 570) | def info(self, *args) -> None:
method warn (line 580) | def warn(self, *args) -> None:
method error (line 590) | def error(self, *args) -> None:
method set_level (line 602) | def set_level(self, level: int) -> None:
method get_dir (line 610) | def get_dir(self) -> str | None:
method close (line 619) | def close(self) -> None:
method _do_log (line 628) | def _do_log(self, args: tuple[Any, ...]) -> None:
function configure (line 639) | def configure(folder: str | None = None, format_strings: list[str] | Non...
function read_json (line 675) | def read_json(filename: str) -> pandas.DataFrame:
function read_csv (line 689) | def read_csv(filename: str) -> pandas.DataFrame:
FILE: stable_baselines3/common/monitor.py
class Monitor (line 15) | class Monitor(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
method __init__ (line 31) | def __init__(
method reset (line 64) | def reset(self, **kwargs) -> tuple[ObsType, dict[str, Any]]:
method step (line 85) | def step(self, action: ActType) -> tuple[ObsType, SupportsFloat, bool,...
method close (line 113) | def close(self) -> None:
method get_total_steps (line 121) | def get_total_steps(self) -> int:
method get_episode_rewards (line 129) | def get_episode_rewards(self) -> list[float]:
method get_episode_lengths (line 137) | def get_episode_lengths(self) -> list[int]:
method get_episode_times (line 145) | def get_episode_times(self) -> list[float]:
class LoadMonitorResultsError (line 154) | class LoadMonitorResultsError(Exception):
class ResultsWriter (line 162) | class ResultsWriter:
method __init__ (line 175) | def __init__(
method write_row (line 203) | def write_row(self, epinfo: dict[str, float]) -> None:
method close (line 213) | def close(self) -> None:
function get_monitor_files (line 220) | def get_monitor_files(path: str) -> list[str]:
function load_results (line 230) | def load_results(path: str) -> pandas.DataFrame:
FILE: stable_baselines3/common/noise.py
class ActionNoise (line 9) | class ActionNoise(ABC):
method __init__ (line 14) | def __init__(self) -> None:
method reset (line 17) | def reset(self) -> None:
method __call__ (line 24) | def __call__(self) -> np.ndarray:
class NormalActionNoise (line 28) | class NormalActionNoise(ActionNoise):
method __init__ (line 37) | def __init__(self, mean: np.ndarray, sigma: np.ndarray, dtype: DTypeLi...
method __call__ (line 43) | def __call__(self) -> np.ndarray:
method __repr__ (line 46) | def __repr__(self) -> str:
class OrnsteinUhlenbeckActionNoise (line 50) | class OrnsteinUhlenbeckActionNoise(ActionNoise):
method __init__ (line 64) | def __init__(
method __call__ (line 83) | def __call__(self) -> np.ndarray:
method reset (line 92) | def reset(self) -> None:
method __repr__ (line 98) | def __repr__(self) -> str:
class VectorizedActionNoise (line 102) | class VectorizedActionNoise(ActionNoise):
method __init__ (line 110) | def __init__(self, base_noise: ActionNoise, n_envs: int) -> None:
method reset (line 120) | def reset(self, indices: Iterable[int] | None = None) -> None:
method __repr__ (line 133) | def __repr__(self) -> str:
method __call__ (line 136) | def __call__(self) -> np.ndarray:
method base_noise (line 144) | def base_noise(self) -> ActionNoise:
method base_noise (line 148) | def base_noise(self, base_noise: ActionNoise) -> None:
method noises (line 156) | def noises(self) -> list[ActionNoise]:
method noises (line 160) | def noises(self, noises: list[ActionNoise]) -> None:
FILE: stable_baselines3/common/off_policy_algorithm.py
class OffPolicyAlgorithm (line 27) | class OffPolicyAlgorithm(BaseAlgorithm):
method __init__ (line 81) | def __init__(
method _convert_train_freq (line 150) | def _convert_train_freq(self) -> None:
method _setup_model (line 174) | def _setup_model(self) -> None:
method save_replay_buffer (line 217) | def save_replay_buffer(self, path: str | pathlib.Path | io.BufferedIOB...
method load_replay_buffer (line 227) | def load_replay_buffer(
method _setup_learn (line 259) | def _setup_learn(
method learn (line 312) | def learn(
method train (line 360) | def train(self, gradient_steps: int, batch_size: int) -> None:
method _sample_action (line 367) | def _sample_action(
method dump_logs (line 417) | def dump_logs(self) -> None:
method _on_step (line 441) | def _on_step(self) -> None:
method _store_transition (line 449) | def _store_transition(
method collect_rollouts (line 514) | def collect_rollouts(
FILE: stable_baselines3/common/on_policy_algorithm.py
class OnPolicyAlgorithm (line 21) | class OnPolicyAlgorithm(BaseAlgorithm):
method __init__ (line 61) | def __init__(
method _setup_model (line 115) | def _setup_model(self) -> None:
method _maybe_recommend_cpu (line 142) | def _maybe_recommend_cpu(self, mlp_class_name: str = "ActorCriticPolic...
method collect_rollouts (line 162) | def collect_rollouts(
method train (line 270) | def train(self) -> None:
method dump_logs (line 277) | def dump_logs(self, iteration: int = 0) -> None:
method learn (line 300) | def learn(
method _get_torch_save_params (line 343) | def _get_torch_save_params(self) -> tuple[list[str], list[str]]:
FILE: stable_baselines3/common/policies.py
class BaseModel (line 39) | class BaseModel(nn.Module):
method __init__ (line 63) | def __init__(
method _update_features_extractor (line 96) | def _update_features_extractor(
method make_features_extractor (line 118) | def make_features_extractor(self) -> BaseFeaturesExtractor:
method extract_features (line 122) | def extract_features(self, obs: PyTorchObs, features_extractor: BaseFe...
method _get_constructor_parameters (line 133) | def _get_constructor_parameters(self) -> dict[str, Any]:
method device (line 149) | def device(self) -> th.device:
method save (line 158) | def save(self, path: str) -> None:
method load (line 167) | def load(cls: type[SelfBaseModel], path: str, device: th.device | str ...
method load_from_vector (line 187) | def load_from_vector(self, vector: np.ndarray) -> None:
method parameters_to_vector (line 195) | def parameters_to_vector(self) -> np.ndarray:
method set_training_mode (line 203) | def set_training_mode(self, mode: bool) -> None:
method is_vectorized_observation (line 213) | def is_vectorized_observation(self, observation: np.ndarray | dict[str...
method obs_to_tensor (line 236) | def obs_to_tensor(self, observation: np.ndarray | dict[str, np.ndarray...
class BasePolicy (line 280) | class BasePolicy(BaseModel, ABC):
method __init__ (line 293) | def __init__(self, *args, squash_output: bool = False, **kwargs):
method _dummy_schedule (line 298) | def _dummy_schedule(progress_remaining: float) -> float:
method squash_output (line 304) | def squash_output(self) -> bool:
method init_weights (line 309) | def init_weights(module: nn.Module, gain: float = 1) -> None:
method _predict (line 319) | def _predict(self, observation: PyTorchObs, deterministic: bool = Fals...
method predict (line 331) | def predict(
method scale_action (line 388) | def scale_action(self, action: np.ndarray) -> np.ndarray:
method unscale_action (line 402) | def unscale_action(self, scaled_action: np.ndarray) -> np.ndarray:
class ActorCriticPolicy (line 416) | class ActorCriticPolicy(BasePolicy):
method __init__ (line 448) | def __init__(
method _get_constructor_parameters (line 537) | def _get_constructor_parameters(self) -> dict[str, Any]:
method reset_noise (line 561) | def reset_noise(self, n_envs: int = 1) -> None:
method _build_mlp_extractor (line 570) | def _build_mlp_extractor(self) -> None:
method _build (line 585) | def _build(self, lr_schedule: Schedule) -> None:
method forward (line 636) | def forward(self, obs: th.Tensor, deterministic: bool = False) -> tupl...
method extract_features (line 660) | def extract_features( # type: ignore[override]
method _get_action_dist_from_latent (line 684) | def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> Distri...
method _predict (line 709) | def _predict(self, observation: PyTorchObs, deterministic: bool = Fals...
method evaluate_actions (line 719) | def evaluate_actions(self, obs: PyTorchObs, actions: th.Tensor) -> tup...
method get_distribution (line 743) | def get_distribution(self, obs: PyTorchObs) -> Distribution:
method predict_values (line 754) | def predict_values(self, obs: PyTorchObs) -> th.Tensor:
class ActorCriticCnnPolicy (line 766) | class ActorCriticCnnPolicy(ActorCriticPolicy):
method __init__ (line 798) | def __init__(
class MultiInputActorCriticPolicy (line 839) | class MultiInputActorCriticPolicy(ActorCriticPolicy):
method __init__ (line 871) | def __init__(
class ContinuousCritic (line 912) | class ContinuousCritic(BaseModel):
method __init__ (line 941) | def __init__(
method forward (line 971) | def forward(self, obs: th.Tensor, actions: th.Tensor) -> tuple[th.Tens...
method q1_forward (line 979) | def q1_forward(self, obs: th.Tensor, actions: th.Tensor) -> th.Tensor:
FILE: stable_baselines3/common/preprocessing.py
function is_image_space_channels_first (line 9) | def is_image_space_channels_first(observation_space: spaces.Box) -> bool:
function is_image_space (line 26) | def is_image_space(
function maybe_transpose (line 71) | def maybe_transpose(observation: np.ndarray, observation_space: spaces.S...
function preprocess_obs (line 91) | def preprocess_obs(
function get_obs_shape (line 142) | def get_obs_shape(
function get_flattened_obs_dim (line 169) | def get_flattened_obs_dim(observation_space: spaces.Space) -> int:
function get_action_dim (line 188) | def get_action_dim(action_space: spaces.Space) -> int:
function check_for_nested_spaces (line 213) | def check_for_nested_spaces(obs_space: spaces.Space) -> None:
FILE: stable_baselines3/common/results_plotter.py
function rolling_window (line 19) | def rolling_window(array: np.ndarray, window: int) -> np.ndarray:
function window_func (line 32) | def window_func(var_1: np.ndarray, var_2: np.ndarray, window: int, func:...
function ts2xy (line 47) | def ts2xy(data_frame: pd.DataFrame, x_axis: str) -> tuple[np.ndarray, np...
function plot_curves (line 72) | def plot_curves(
function plot_results (line 102) | def plot_results(
FILE: stable_baselines3/common/running_mean_std.py
class RunningMeanStd (line 4) | class RunningMeanStd:
method __init__ (line 5) | def __init__(self, epsilon: float = 1e-4, shape: tuple[int, ...] = ()):
method copy (line 17) | def copy(self) -> "RunningMeanStd":
method combine (line 27) | def combine(self, other: "RunningMeanStd") -> None:
method update (line 35) | def update(self, arr: np.ndarray) -> None:
method update_from_moments (line 41) | def update_from_moments(self, batch_mean: np.ndarray, batch_var: np.nd...
FILE: stable_baselines3/common/save_util.py
function recursive_getattr (line 25) | def recursive_getattr(obj: Any, attr: str, *args) -> Any:
function recursive_setattr (line 44) | def recursive_setattr(obj: Any, attr: str, val: Any) -> None:
function is_json_serializable (line 60) | def is_json_serializable(item: Any) -> bool:
function data_to_json (line 76) | def data_to_json(data: dict[str, Any]) -> str:
function json_to_data (line 131) | def json_to_data(json_string: str, custom_objects: dict[str, Any] | None...
function open_path (line 182) | def open_path(
function open_path_str (line 226) | def open_path_str(path: str, mode: str, verbose: int = 0, suffix: str | ...
function open_path_pathlib (line 244) | def open_path_pathlib(path: pathlib.Path, mode: str, verbose: int = 0, s...
function save_to_zip_file (line 294) | def save_to_zip_file(
function save_to_pkl (line 339) | def save_to_pkl(path: str | pathlib.Path | io.BufferedIOBase, obj: Any, ...
function load_from_pkl (line 359) | def load_from_pkl(path: str | pathlib.Path | io.BufferedIOBase, verbose:...
function load_from_zip_file (line 376) | def load_from_zip_file(
FILE: stable_baselines3/common/sb2_compat/rmsprop_tf_like.py
class RMSpropTFLike (line 8) | class RMSpropTFLike(Optimizer):
method __init__ (line 47) | def __init__(
method __setstate__ (line 71) | def __setstate__(self, state: dict[str, Any]) -> None:
method step (line 78) | def step(self, closure: Callable[[], float] | None = None) -> float | ...
FILE: stable_baselines3/common/torch_layers.py
class BaseFeaturesExtractor (line 11) | class BaseFeaturesExtractor(nn.Module):
method __init__ (line 19) | def __init__(self, observation_space: gym.Space, features_dim: int = 0...
method features_dim (line 26) | def features_dim(self) -> int:
class FlattenExtractor (line 31) | class FlattenExtractor(BaseFeaturesExtractor):
method __init__ (line 39) | def __init__(self, observation_space: gym.Space) -> None:
method forward (line 43) | def forward(self, observations: th.Tensor) -> th.Tensor:
class NatureCNN (line 47) | class NatureCNN(BaseFeaturesExtractor):
method __init__ (line 63) | def __init__(
method forward (line 104) | def forward(self, observations: th.Tensor) -> th.Tensor:
function create_mlp (line 108) | def create_mlp(
class MlpExtractor (line 184) | class MlpExtractor(nn.Module):
method __init__ (line 209) | def __init__(
method forward (line 250) | def forward(self, features: th.Tensor) -> tuple[th.Tensor, th.Tensor]:
method forward_actor (line 257) | def forward_actor(self, features: th.Tensor) -> th.Tensor:
method forward_critic (line 260) | def forward_critic(self, features: th.Tensor) -> th.Tensor:
class CombinedExtractor (line 264) | class CombinedExtractor(BaseFeaturesExtractor):
method __init__ (line 280) | def __init__(
method forward (line 306) | def forward(self, observations: TensorDict) -> th.Tensor:
function get_actor_critic_arch (line 314) | def get_actor_critic_arch(net_arch: list[int] | dict[str, list[int]]) ->...
FILE: stable_baselines3/common/type_aliases.py
class RolloutBufferSamples (line 32) | class RolloutBufferSamples(NamedTuple):
class DictRolloutBufferSamples (line 41) | class DictRolloutBufferSamples(NamedTuple):
class ReplayBufferSamples (line 50) | class ReplayBufferSamples(NamedTuple):
class DictReplayBufferSamples (line 60) | class DictReplayBufferSamples(NamedTuple):
class RolloutReturn (line 69) | class RolloutReturn(NamedTuple):
class TrainFrequencyUnit (line 75) | class TrainFrequencyUnit(Enum):
class TrainFreq (line 80) | class TrainFreq(NamedTuple):
class PolicyPredictor (line 85) | class PolicyPredictor(Protocol):
method predict (line 86) | def predict(
FILE: stable_baselines3/common/utils.py
function set_random_seed (line 28) | def set_random_seed(seed: int, using_cuda: bool = False) -> None:
function explained_variance (line 49) | def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> float:
function update_learning_rate (line 68) | def update_learning_rate(optimizer: th.optim.Optimizer, learning_rate: f...
class FloatSchedule (line 80) | class FloatSchedule:
method __init__ (line 89) | def __init__(self, value_schedule: Schedule | float):
method __call__ (line 98) | def __call__(self, progress_remaining: float) -> float:
method __repr__ (line 103) | def __repr__(self) -> str:
class LinearSchedule (line 107) | class LinearSchedule:
method __init__ (line 120) | def __init__(self, start: float, end: float, end_fraction: float) -> N...
method __call__ (line 125) | def __call__(self, progress_remaining: float) -> float:
method __repr__ (line 131) | def __repr__(self) -> str:
class ConstantSchedule (line 135) | class ConstantSchedule:
method __init__ (line 143) | def __init__(self, val: float):
method __call__ (line 146) | def __call__(self, _: float) -> float:
method __repr__ (line 149) | def __repr__(self) -> str:
function get_schedule_fn (line 158) | def get_schedule_fn(value_schedule: Schedule | float) -> Schedule:
function get_linear_fn (line 179) | def get_linear_fn(start: float, end: float, end_fraction: float) -> Sche...
function constant_fn (line 204) | def constant_fn(val: float) -> Schedule:
function get_device (line 223) | def get_device(device: th.device | str = "auto") -> th.device:
function get_latest_run_id (line 246) | def get_latest_run_id(log_path: str = "", log_name: str = "") -> int:
function configure_logger (line 265) | def configure_logger(
function check_for_correct_spaces (line 302) | def check_for_correct_spaces(env: GymEnv, observation_space: spaces.Spac...
function check_shape_equal (line 320) | def check_shape_equal(space1: spaces.Space, space2: spaces.Space) -> None:
function is_vectorized_box_observation (line 340) | def is_vectorized_box_observation(observation: np.ndarray, observation_s...
function is_vectorized_discrete_observation (line 361) | def is_vectorized_discrete_observation(observation: int | np.ndarray, ob...
function is_vectorized_multidiscrete_observation (line 381) | def is_vectorized_multidiscrete_observation(observation: np.ndarray, obs...
function is_vectorized_multibinary_observation (line 402) | def is_vectorized_multibinary_observation(observation: np.ndarray, obser...
function is_vectorized_dict_observation (line 423) | def is_vectorized_dict_observation(observation: np.ndarray, observation_...
function is_vectorized_observation (line 467) | def is_vectorized_observation(observation: int | np.ndarray, observation...
function safe_mean (line 493) | def safe_mean(arr: np.ndarray | list | deque) -> float:
function get_parameters_by_name (line 504) | def get_parameters_by_name(model: th.nn.Module, included_names: Iterable...
function zip_strict (line 517) | def zip_strict(*iterables: Iterable) -> Iterable:
function polyak_update (line 530) | def polyak_update(
function obs_as_tensor (line 556) | def obs_as_tensor(obs: np.ndarray | dict[str, np.ndarray], device: th.de...
function should_collect_more_steps (line 572) | def should_collect_more_steps(
function get_system_info (line 600) | def get_system_info(print_info: bool = True) -> tuple[dict[str, str], str]:
FILE: stable_baselines3/common/vec_env/__init__.py
function unwrap_vec_wrapper (line 19) | def unwrap_vec_wrapper(env: VecEnv, vec_wrapper_class: type[VecEnvWrappe...
function unwrap_vec_normalize (line 35) | def unwrap_vec_normalize(env: VecEnv) -> VecNormalize | None:
function is_vecenv_wrapped (line 45) | def is_vecenv_wrapped(env: VecEnv, vec_wrapper_class: type[VecEnvWrapper...
function sync_envs_normalization (line 56) | def sync_envs_normalization(env: VecEnv, eval_env: VecEnv) -> None:
FILE: stable_baselines3/common/vec_env/base_vec_env.py
function tile_images (line 24) | def tile_images(images_nhwc: Sequence[np.ndarray]) -> np.ndarray: # pra...
class VecEnv (line 50) | class VecEnv(ABC):
method __init__ (line 59) | def __init__(
method _reset_seeds (line 96) | def _reset_seeds(self) -> None:
method _reset_options (line 102) | def _reset_options(self) -> None:
method reset (line 109) | def reset(self) -> VecEnvObs:
method step_async (line 123) | def step_async(self, actions: np.ndarray) -> None:
method step_wait (line 135) | def step_wait(self) -> VecEnvStepReturn:
method close (line 144) | def close(self) -> None:
method has_attr (line 150) | def has_attr(self, attr_name: str) -> bool:
method get_attr (line 166) | def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> l...
method set_attr (line 177) | def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices ...
method env_method (line 189) | def env_method(self, method_name: str, *method_args, indices: VecEnvIn...
method env_is_wrapped (line 202) | def env_is_wrapped(self, wrapper_class: type[gym.Wrapper], indices: Ve...
method step (line 214) | def step(self, actions: np.ndarray) -> VecEnvStepReturn:
method get_images (line 224) | def get_images(self) -> Sequence[np.ndarray | None]:
method render (line 230) | def render(self, mode: str | None = None) -> np.ndarray | None:
method seed (line 292) | def seed(self, seed: int | None = None) -> Sequence[None | int]:
method set_options (line 311) | def set_options(self, options: list[dict] | dict | None = None) -> None:
method unwrapped (line 328) | def unwrapped(self) -> "VecEnv":
method getattr_depth_check (line 334) | def getattr_depth_check(self, name: str, already_found: bool) -> str |...
method _get_indices (line 346) | def _get_indices(self, indices: VecEnvIndices) -> Iterable[int]:
class VecEnvWrapper (line 360) | class VecEnvWrapper(VecEnv):
method __init__ (line 369) | def __init__(
method step_async (line 384) | def step_async(self, actions: np.ndarray) -> None:
method reset (line 388) | def reset(self) -> VecEnvObs:
method step_wait (line 392) | def step_wait(self) -> VecEnvStepReturn:
method seed (line 395) | def seed(self, seed: int | None = None) -> Sequence[None | int]:
method set_options (line 398) | def set_options(self, options: list[dict] | dict | None = None) -> None:
method close (line 401) | def close(self) -> None:
method render (line 404) | def render(self, mode: str | None = None) -> np.ndarray | None:
method get_images (line 407) | def get_images(self) -> Sequence[np.ndarray | None]:
method has_attr (line 410) | def has_attr(self, attr_name: str) -> bool:
method get_attr (line 413) | def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> l...
method set_attr (line 416) | def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices ...
method env_method (line 419) | def env_method(self, method_name: str, *method_args, indices: VecEnvIn...
method env_is_wrapped (line 422) | def env_is_wrapped(self, wrapper_class: type[gym.Wrapper], indices: Ve...
method __getattr__ (line 425) | def __getattr__(self, name: str) -> Any:
method _get_all_attributes (line 441) | def _get_all_attributes(self) -> dict[str, Any]:
method getattr_recursive (line 450) | def getattr_recursive(self, name: str) -> Any:
method getattr_depth_check (line 468) | def getattr_depth_check(self, name: str, already_found: bool) -> str |...
class CloudpickleWrapper (line 487) | class CloudpickleWrapper:
method __init__ (line 494) | def __init__(self, var: Any):
method __getstate__ (line 497) | def __getstate__(self) -> Any:
method __setstate__ (line 500) | def __setstate__(self, var: Any) -> None:
FILE: stable_baselines3/common/vec_env/dummy_vec_env.py
class DummyVecEnv (line 15) | class DummyVecEnv(VecEnv):
method __init__ (line 30) | def __init__(self, env_fns: list[Callable[[], gym.Env]]):
method step_async (line 53) | def step_async(self, actions: np.ndarray) -> None:
method step_wait (line 56) | def step_wait(self) -> VecEnvStepReturn:
method reset (line 75) | def reset(self) -> VecEnvObs:
method close (line 85) | def close(self) -> None:
method get_images (line 89) | def get_images(self) -> Sequence[np.ndarray | None]:
method render (line 97) | def render(self, mode: str | None = None) -> np.ndarray | None:
method _save_obs (line 106) | def _save_obs(self, env_idx: int, obs: VecEnvObs) -> None:
method _obs_from_buf (line 113) | def _obs_from_buf(self) -> VecEnvObs:
method get_attr (line 116) | def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> l...
method set_attr (line 121) | def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices ...
method env_method (line 127) | def env_method(self, method_name: str, *method_args, indices: VecEnvIn...
method env_is_wrapped (line 132) | def env_is_wrapped(self, wrapper_class: type[gym.Wrapper], indices: Ve...
method _get_target_envs (line 140) | def _get_target_envs(self, indices: VecEnvIndices) -> list[gym.Env]:
FILE: stable_baselines3/common/vec_env/patch_gym.py
function _patch_env (line 15) | def _patch_env(env: Union["gym.Env", gymnasium.Env]) -> gymnasium.Env: ...
function _convert_space (line 63) | def _convert_space(space: Union["gym.Space", gymnasium.Space]) -> gymnas...
FILE: stable_baselines3/common/vec_env/stacked_observations.py
class StackedObservations (line 13) | class StackedObservations(Generic[TObs]):
method __init__ (line 28) | def __init__(
method compute_stacking (line 68) | def compute_stacking(
method reset (line 102) | def reset(self, observation: TObs) -> TObs:
method update (line 119) | def update(
FILE: stable_baselines3/common/vec_env/subproc_vec_env.py
function _worker (line 20) | def _worker( # noqa: C901
class SubprocVecEnv (line 79) | class SubprocVecEnv(VecEnv):
method __init__ (line 103) | def __init__(self, env_fns: list[Callable[[], gym.Env]], start_method:...
method step_async (line 131) | def step_async(self, actions: np.ndarray) -> None:
method step_wait (line 136) | def step_wait(self) -> VecEnvStepReturn:
method reset (line 142) | def reset(self) -> VecEnvObs:
method close (line 152) | def close(self) -> None:
method get_images (line 164) | def get_images(self) -> Sequence[np.ndarray | None]:
method has_attr (line 176) | def has_attr(self, attr_name: str) -> bool:
method get_attr (line 183) | def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> l...
method set_attr (line 190) | def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices ...
method env_method (line 198) | def env_method(self, method_name: str, *method_args, indices: VecEnvIn...
method env_is_wrapped (line 205) | def env_is_wrapped(self, wrapper_class: type[gym.Wrapper], indices: Ve...
method _get_target_remotes (line 212) | def _get_target_remotes(self, indices: VecEnvIndices) -> list[Any]:
function _stack_obs (line 224) | def _stack_obs(obs_list: list[VecEnvObs] | tuple[VecEnvObs], space: spac...
FILE: stable_baselines3/common/vec_env/util.py
function dict_to_obs (line 14) | def dict_to_obs(obs_space: spaces.Space, obs_dict: dict[Any, np.ndarray]...
function obs_space_info (line 35) | def obs_space_info(obs_space: spaces.Space) -> tuple[list[str], dict[Any...
FILE: stable_baselines3/common/vec_env/vec_check_nan.py
class VecCheckNan (line 9) | class VecCheckNan(VecEnvWrapper):
method __init__ (line 20) | def __init__(self, venv: VecEnv, raise_exception: bool = False, warn_o...
method step_async (line 33) | def step_async(self, actions: np.ndarray) -> None:
method step_wait (line 38) | def step_wait(self) -> VecEnvStepReturn:
method reset (line 44) | def reset(self) -> VecEnvObs:
method check_array_value (line 50) | def check_array_value(self, name: str, value: np.ndarray) -> list[tupl...
method _check_val (line 67) | def _check_val(self, event: str, **kwargs) -> None:
FILE: stable_baselines3/common/vec_env/vec_extract_dict_obs.py
class VecExtractDictObs (line 7) | class VecExtractDictObs(VecEnvWrapper):
method __init__ (line 15) | def __init__(self, venv: VecEnv, key: str):
method reset (line 22) | def reset(self) -> np.ndarray:
method step_wait (line 27) | def step_wait(self) -> VecEnvStepReturn:
FILE: stable_baselines3/common/vec_env/vec_frame_stack.py
class VecFrameStack (line 11) | class VecFrameStack(VecEnvWrapper):
method __init__ (line 22) | def __init__(self, venv: VecEnv, n_stack: int, channels_order: str | M...
method step_wait (line 31) | def step_wait(
method reset (line 43) | def reset(self) -> np.ndarray | dict[str, np.ndarray]:
FILE: stable_baselines3/common/vec_env/vec_monitor.py
class VecMonitor (line 9) | class VecMonitor(VecEnvWrapper):
method __init__ (line 25) | def __init__(
method reset (line 68) | def reset(self) -> VecEnvObs:
method step_wait (line 74) | def step_wait(self) -> VecEnvStepReturn:
method close (line 96) | def close(self) -> None:
FILE: stable_baselines3/common/vec_env/vec_normalize.py
class VecNormalize (line 15) | class VecNormalize(VecEnvWrapper):
method __init__ (line 35) | def __init__(
method _sanity_checks (line 100) | def _sanity_checks(self) -> None:
method __getstate__ (line 128) | def __getstate__(self) -> dict[str, Any]:
method __setstate__ (line 141) | def __setstate__(self, state: dict[str, Any]) -> None:
method set_venv (line 155) | def set_venv(self, venv: VecEnv) -> None:
method step_wait (line 174) | def step_wait(self) -> VecEnvStepReturn:
method _update_reward (line 209) | def _update_reward(self, reward: np.ndarray) -> None:
method _normalize_obs (line 214) | def _normalize_obs(self, obs: np.ndarray, obs_rms: RunningMeanStd) -> ...
method _unnormalize_obs (line 223) | def _unnormalize_obs(self, obs: np.ndarray, obs_rms: RunningMeanStd) -...
method normalize_obs (line 232) | def normalize_obs(self, obs: np.ndarray | dict[str, np.ndarray]) -> np...
method normalize_reward (line 250) | def normalize_reward(self, reward: np.ndarray) -> np.ndarray:
method unnormalize_obs (line 261) | def unnormalize_obs(self, obs: np.ndarray | dict[str, np.ndarray]) -> ...
method unnormalize_reward (line 274) | def unnormalize_reward(self, reward: np.ndarray) -> np.ndarray:
method get_original_obs (line 279) | def get_original_obs(self) -> np.ndarray | dict[str, np.ndarray]:
method get_original_reward (line 286) | def get_original_reward(self) -> np.ndarray:
method reset (line 292) | def reset(self) -> np.ndarray | dict[str, np.ndarray]:
method load (line 311) | def load(load_path: str, venv: VecEnv) -> "VecNormalize":
method save (line 324) | def save(self, save_path: str) -> None:
FILE: stable_baselines3/common/vec_env/vec_transpose.py
class VecTransposeImage (line 10) | class VecTransposeImage(VecEnvWrapper):
method __init__ (line 20) | def __init__(self, venv: VecEnv, skip: bool = False):
method transpose_space (line 46) | def transpose_space(observation_space: spaces.Box, key: str = "") -> s...
method transpose_image (line 64) | def transpose_image(image: np.ndarray) -> np.ndarray:
method transpose_observations (line 75) | def transpose_observations(self, observations: np.ndarray | dict) -> n...
method step_wait (line 95) | def step_wait(self) -> VecEnvStepReturn:
method reset (line 108) | def reset(self) -> np.ndarray | dict:
method close (line 116) | def close(self) -> None:
FILE: stable_baselines3/common/vec_env/vec_video_recorder.py
class VecVideoRecorder (line 13) | class VecVideoRecorder(VecEnvWrapper):
method __init__ (line 35) | def __init__(
method reset (line 81) | def reset(self) -> VecEnvObs:
method _start_video_recorder (line 87) | def _start_video_recorder(self) -> None:
method _video_enabled (line 94) | def _video_enabled(self) -> bool:
method step_wait (line 97) | def step_wait(self) -> VecEnvStepReturn:
method _capture_frame (line 111) | def _capture_frame(self) -> None:
method close (line 124) | def close(self) -> None:
method _start_recording (line 130) | def _start_recording(self) -> None:
method _stop_recording (line 137) | def _stop_recording(self) -> None:
method __del__ (line 154) | def __del__(self) -> None:
FILE: stable_baselines3/ddpg/ddpg.py
class DDPG (line 14) | class DDPG(TD3):
method __init__ (line 57) | def __init__(
method learn (line 117) | def learn(
FILE: stable_baselines3/dqn/dqn.py
class DQN (line 19) | class DQN(OffPolicyAlgorithm):
method __init__ (line 77) | def __init__(
method _setup_model (line 146) | def _setup_model(self) -> None:
method _create_aliases (line 167) | def _create_aliases(self) -> None:
method _on_step (line 171) | def _on_step(self) -> None:
method train (line 187) | def train(self, gradient_steps: int, batch_size: int = 100) -> None:
method predict (line 233) | def predict(
method learn (line 263) | def learn(
method _excluded_save_params (line 281) | def _excluded_save_params(self) -> list[str]:
method _get_torch_save_params (line 284) | def _get_torch_save_params(self) -> tuple[list[str], list[str]]:
FILE: stable_baselines3/dqn/policies.py
class QNetwork (line 18) | class QNetwork(BasePolicy):
method __init__ (line 32) | def __init__(
method forward (line 59) | def forward(self, obs: PyTorchObs) -> th.Tensor:
method _predict (line 68) | def _predict(self, observation: PyTorchObs, deterministic: bool = True...
method _get_constructor_parameters (line 74) | def _get_constructor_parameters(self) -> dict[str, Any]:
class DQNPolicy (line 88) | class DQNPolicy(BasePolicy):
method __init__ (line 111) | def __init__(
method _build (line 153) | def _build(self, lr_schedule: Schedule) -> None:
method make_q_net (line 175) | def make_q_net(self) -> QNetwork:
method forward (line 180) | def forward(self, obs: PyTorchObs, deterministic: bool = True) -> th.T...
method _predict (line 183) | def _predict(self, obs: PyTorchObs, deterministic: bool = True) -> th....
method _get_constructor_parameters (line 186) | def _get_constructor_parameters(self) -> dict[str, Any]:
method set_training_mode (line 202) | def set_training_mode(self, mode: bool) -> None:
class CnnPolicy (line 217) | class CnnPolicy(DQNPolicy):
method __init__ (line 235) | def __init__(
class MultiInputPolicy (line 262) | class MultiInputPolicy(DQNPolicy):
method __init__ (line 280) | def __init__(
FILE: stable_baselines3/her/goal_selection_strategy.py
class GoalSelectionStrategy (line 4) | class GoalSelectionStrategy(Enum):
FILE: stable_baselines3/her/her_replay_buffer.py
class HerReplayBuffer (line 15) | class HerReplayBuffer(DictReplayBuffer):
method __init__ (line 50) | def __init__(
method __getstate__ (line 101) | def __getstate__(self) -> dict[str, Any]:
method __setstate__ (line 112) | def __setstate__(self, state: dict[str, Any]) -> None:
method set_env (line 124) | def set_env(self, env: VecEnv) -> None:
method add (line 135) | def add( # type: ignore[override]
method _compute_episode_length (line 169) | def _compute_episode_length(self, env_idx: int) -> None:
method sample (line 186) | def sample(self, batch_size: int, env: VecNormalize | None = None) -> ...
method _get_real_samples (line 248) | def _get_real_samples(
method _get_virtual_samples (line 287) | def _get_virtual_samples(
method _sample_goals (line 355) | def _sample_goals(self, batch_indices: np.ndarray, env_indices: np.nda...
method truncate_last_trajectory (line 386) | def truncate_last_trajectory(self) -> None:
FILE: stable_baselines3/ppo/ppo.py
class PPO (line 18) | class PPO(OnPolicyAlgorithm):
method __init__ (line 80) | def __init__(
method _setup_model (line 173) | def _setup_model(self) -> None:
method train (line 184) | def train(self) -> None:
method learn (line 302) | def learn(
FILE: stable_baselines3/sac/policies.py
class Actor (line 25) | class Actor(BasePolicy):
method __init__ (line 50) | def __init__(
method _get_constructor_parameters (line 105) | def _get_constructor_parameters(self) -> dict[str, Any]:
method get_std (line 123) | def get_std(self) -> th.Tensor:
method reset_noise (line 137) | def reset_noise(self, batch_size: int = 1) -> None:
method get_action_dist_params (line 147) | def get_action_dist_params(self, obs: PyTorchObs) -> tuple[th.Tensor, ...
method forward (line 167) | def forward(self, obs: PyTorchObs, deterministic: bool = False) -> th....
method action_log_prob (line 172) | def action_log_prob(self, obs: PyTorchObs) -> tuple[th.Tensor, th.Tens...
method _predict (line 177) | def _predict(self, observation: PyTorchObs, deterministic: bool = Fals...
class SACPolicy (line 181) | class SACPolicy(BasePolicy):
method __init__ (line 214) | def __init__(
method _build (line 280) | def _build(self, lr_schedule: Schedule) -> None:
method _get_constructor_parameters (line 312) | def _get_constructor_parameters(self) -> dict[str, Any]:
method reset_noise (line 333) | def reset_noise(self, batch_size: int = 1) -> None:
method make_actor (line 341) | def make_actor(self, features_extractor: BaseFeaturesExtractor | None ...
method make_critic (line 345) | def make_critic(self, features_extractor: BaseFeaturesExtractor | None...
method forward (line 349) | def forward(self, obs: PyTorchObs, deterministic: bool = False) -> th....
method _predict (line 352) | def _predict(self, observation: PyTorchObs, deterministic: bool = Fals...
method set_training_mode (line 355) | def set_training_mode(self, mode: bool) -> None:
class CnnPolicy (line 371) | class CnnPolicy(SACPolicy):
method __init__ (line 398) | def __init__(
class MultiInputPolicy (line 437) | class MultiInputPolicy(SACPolicy):
method __init__ (line 464) | def __init__(
FILE: stable_baselines3/sac/sac.py
class SAC (line 19) | class SAC(OffPolicyAlgorithm):
method __init__ (line 91) | def __init__(
method _setup_model (line 162) | def _setup_model(self) -> None:
method _create_aliases (line 197) | def _create_aliases(self) -> None:
method train (line 202) | def train(self, gradient_steps: int, batch_size: int = 64) -> None:
method learn (line 304) | def learn(
method _excluded_save_params (line 322) | def _excluded_save_params(self) -> list[str]:
method _get_torch_save_params (line 325) | def _get_torch_save_params(self) -> tuple[list[str], list[str]]:
FILE: stable_baselines3/td3/policies.py
class Actor (line 20) | class Actor(BasePolicy):
method __init__ (line 35) | def __init__(
method _get_constructor_parameters (line 62) | def _get_constructor_parameters(self) -> dict[str, Any]:
method forward (line 75) | def forward(self, obs: th.Tensor) -> th.Tensor:
method _predict (line 80) | def _predict(self, observation: PyTorchObs, deterministic: bool = Fals...
class TD3Policy (line 86) | class TD3Policy(BasePolicy):
method __init__ (line 114) | def __init__(
method _build (line 172) | def _build(self, lr_schedule: Schedule) -> None:
method _get_constructor_parameters (line 210) | def _get_constructor_parameters(self) -> dict[str, Any]:
method make_actor (line 228) | def make_actor(self, features_extractor: BaseFeaturesExtractor | None ...
method make_critic (line 232) | def make_critic(self, features_extractor: BaseFeaturesExtractor | None...
method forward (line 236) | def forward(self, observation: PyTorchObs, deterministic: bool = False...
method _predict (line 239) | def _predict(self, observation: PyTorchObs, deterministic: bool = Fals...
method set_training_mode (line 244) | def set_training_mode(self, mode: bool) -> None:
class CnnPolicy (line 260) | class CnnPolicy(TD3Policy):
method __init__ (line 283) | def __init__(
class MultiInputPolicy (line 314) | class MultiInputPolicy(TD3Policy):
method __init__ (line 337) | def __init__(
FILE: stable_baselines3/td3/td3.py
class TD3 (line 19) | class TD3(OffPolicyAlgorithm):
method __init__ (line 80) | def __init__(
method _setup_model (line 142) | def _setup_model(self) -> None:
method _create_aliases (line 151) | def _create_aliases(self) -> None:
method train (line 157) | def train(self, gradient_steps: int, batch_size: int = 100) -> None:
method learn (line 218) | def learn(
method _excluded_save_params (line 236) | def _excluded_save_params(self) -> list[str]:
method _get_torch_save_params (line 239) | def _get_torch_save_params(self) -> tuple[list[str], list[str]]:
FILE: tests/test_buffers.py
class DummyEnv (line 16) | class DummyEnv(gym.Env):
method __init__ (line 21) | def __init__(self):
method reset (line 29) | def reset(self, *, seed=None, options=None):
method step (line 34) | def step(self, action):
class DummyDictEnv (line 44) | class DummyDictEnv(gym.Env):
method __init__ (line 49) | def __init__(self):
method reset (line 59) | def reset(self, seed=None, options=None):
method step (line 64) | def step(self, action):
function test_env (line 75) | def test_env(env_cls):
function test_replay_buffer_normalization (line 82) | def test_replay_buffer_normalization(replay_buffer_cls):
function test_device_buffer (line 114) | def test_device_buffer(replay_buffer_cls, device):
function test_buffer_dtypes (line 191) | def test_buffer_dtypes(obs_dtype, use_dict, action_space):
function test_custom_rollout_buffer (line 235) | def test_custom_rollout_buffer():
FILE: tests/test_callbacks.py
function select_env (line 26) | def select_env(model_class) -> str:
function test_callbacks (line 34) | def test_callbacks(tmp_path, model_class):
function test_eval_callback_vec_env (line 114) | def test_eval_callback_vec_env():
class AlwaysFailCallback (line 130) | class AlwaysFailCallback(BaseCallback):
method __init__ (line 131) | def __init__(self, *args, callback_false_value, **kwargs):
method _on_step (line 135) | def _on_step(self) -> bool:
function test_callbacks_can_cancel_runs (line 154) | def test_callbacks_can_cancel_runs(model_class, model_kwargs, callback_f...
function test_eval_success_logging (line 164) | def test_eval_success_logging(tmp_path):
function test_eval_callback_logs_are_written_with_the_correct_timestep (line 188) | def test_eval_callback_logs_are_written_with_the_correct_timestep(tmp_pa...
function test_eval_friendly_error (line 214) | def test_eval_friendly_error():
function test_checkpoint_additional_info (line 245) | def test_checkpoint_additional_info(tmp_path):
function test_eval_callback_chaining (line 271) | def test_eval_callback_chaining(tmp_path):
FILE: tests/test_cnn.py
function test_cnn (line 17) | def test_cnn(tmp_path, model_class, share_features_extractor):
function test_vec_transpose_skip (line 70) | def test_vec_transpose_skip(tmp_path, model_class):
function patch_dqn_names_ (line 96) | def patch_dqn_names_(model):
function params_should_match (line 103) | def params_should_match(params, other_params):
function params_should_differ (line 108) | def params_should_differ(params, other_params):
function check_td3_feature_extractor_match (line 113) | def check_td3_feature_extractor_match(model):
function check_td3_feature_extractor_differ (line 121) | def check_td3_feature_extractor_differ(model):
function test_features_extractor_target_net (line 131) | def test_features_extractor_target_net(model_class, share_features_extra...
function test_channel_first_env (line 240) | def test_channel_first_env(tmp_path):
function test_image_space_checks (line 269) | def test_image_space_checks():
function test_image_like_input (line 321) | def test_image_like_input(model_class, normalize_images):
FILE: tests/test_custom_policy.py
function test_flexible_mlp (line 26) | def test_flexible_mlp(model_class, net_arch):
function test_custom_offpolicy (line 36) | def test_custom_offpolicy(model_class, net_arch):
function test_custom_optimizer (line 42) | def test_custom_optimizer(model_class, optimizer_kwargs):
function test_tf_like_rmsprop_optimizer (line 59) | def test_tf_like_rmsprop_optimizer():
function test_dqn_custom_policy (line 64) | def test_dqn_custom_policy():
function test_create_mlp (line 69) | def test_create_mlp():
FILE: tests/test_deterministic.py
function test_deterministic_training_common (line 12) | def test_deterministic_training_common(algo):
FILE: tests/test_dict_env.py
class DummyDictEnv (line 14) | class DummyDictEnv(gym.Env):
method __init__ (line 19) | def __init__(
method seed (line 64) | def seed(self, seed=None):
method step (line 68) | def step(self, action):
method reset (line 73) | def reset(self, *, seed: int | None = None, options: dict | None = None):
method render (line 78) | def render(self):
function test_env (line 86) | def test_env(use_discrete_actions, channel_last, nested_dict_obs, vec_on...
function test_policy_hint (line 96) | def test_policy_hint(policy):
function test_goal_env (line 103) | def test_goal_env(model_class):
function test_consistency (line 111) | def test_consistency(model_class):
function test_dict_spaces (line 157) | def test_dict_spaces(model_class, channel_last):
function test_multiprocessing (line 200) | def test_multiprocessing(model_class):
function test_dict_vec_framestack (line 238) | def test_dict_vec_framestack(model_class, channel_last):
function test_vec_normalize (line 285) | def test_vec_normalize(model_class):
function test_dict_nested (line 324) | def test_dict_nested():
function test_vec_normalize_image (line 340) | def test_vec_normalize_image():
FILE: tests/test_distributions.py
function test_bijector (line 26) | def test_bijector():
function test_squashed_gaussian (line 41) | def test_squashed_gaussian(model_class):
function dummy_model_distribution_obs_and_actions (line 57) | def dummy_model_distribution_obs_and_actions() -> tuple[A2C, np.ndarray,...
function test_get_distribution (line 69) | def test_get_distribution(dummy_model_distribution_obs_and_actions):
function test_predict_values (line 85) | def test_predict_values(dummy_model_distribution_obs_and_actions):
function test_sde_distribution (line 96) | def test_sde_distribution():
function test_entropy (line 121) | def test_entropy(dist):
function test_categorical (line 149) | def test_categorical(dist, CAT_ACTIONS):
function test_kl_divergence (line 174) | def test_kl_divergence(dist_type):
FILE: tests/test_env_checker.py
class ActionDictTestEnv (line 11) | class ActionDictTestEnv(gym.Env):
method step (line 18) | def step(self, action):
method reset (line 26) | def reset(self, *, seed=None, options=None):
method render (line 29) | def render(self):
function test_check_env_dict_action (line 33) | def test_check_env_dict_action():
class CustomEnv (line 40) | class CustomEnv(gym.Env):
method __init__ (line 43) | def __init__(self, render_mode=None):
method reset (line 48) | def reset(self, *, seed=None, options=None):
method step (line 52) | def step(self, action):
function test_check_env_detailed_error (line 121) | def test_check_env_detailed_error(obs_tuple, method):
class LimitedStepsTestEnv (line 146) | class LimitedStepsTestEnv(gym.Env):
method __init__ (line 150) | def __init__(self, steps_before_termination: int = 1):
method reset (line 159) | def reset(self, *, seed: int | None = None, options: dict | None = Non...
method step (line 167) | def step(self, action: np.ndarray) -> tuple[int, float, bool, bool, di...
method render (line 179) | def render(self) -> None:
function test_check_env_single_step_env (line 183) | def test_check_env_single_step_env():
class SimpleGraphEnv (line 190) | class SimpleGraphEnv(CustomEnv):
method __init__ (line 191) | def __init__(self):
class SimpleDictGraphEnv (line 199) | class SimpleDictGraphEnv(CustomEnv):
method __init__ (line 200) | def __init__(self):
function test_check_env_graph_space (line 212) | def test_check_env_graph_space():
class SequenceInDictEnv (line 221) | class SequenceInDictEnv(CustomEnv):
method __init__ (line 224) | def __init__(self):
class SequenceInTupleEnv (line 231) | class SequenceInTupleEnv(CustomEnv):
method __init__ (line 234) | def __init__(self):
class SequenceInOneOfEnv (line 239) | class SequenceInOneOfEnv(CustomEnv):
method __init__ (line 242) | def __init__(self):
function test_check_env_sequence_obs (line 253) | def test_check_env_sequence_obs(env_class):
function test_check_env_sequence_tuple (line 258) | def test_check_env_sequence_tuple():
function test_check_env_oneof (line 266) | def test_check_env_oneof():
FILE: tests/test_envs.py
function test_env (line 32) | def test_env(env_id):
function test_custom_envs (line 52) | def test_custom_envs(env_class):
function test_bit_flipping (line 69) | def test_bit_flipping(kwargs):
function test_high_dimension_action_space (line 90) | def test_high_dimension_action_space():
function test_non_default_spaces (line 134) | def test_non_default_spaces(new_obs_space):
function test_non_default_action_spaces (line 177) | def test_non_default_action_spaces(new_action_space):
function check_reset_assert_error (line 210) | def check_reset_assert_error(env, new_reset_return):
function test_common_failures_reset (line 226) | def test_common_failures_reset():
function check_step_assert_error (line 277) | def check_step_assert_error(env, new_step_return=()):
function test_common_failures_step (line 293) | def test_common_failures_step():
FILE: tests/test_gae.py
class CustomEnv (line 13) | class CustomEnv(gym.Env):
method __init__ (line 14) | def __init__(self, max_steps=8):
method seed (line 21) | def seed(self, seed):
method reset (line 24) | def reset(self, *, seed: int | None = None, options: dict | None = None):
method step (line 30) | def step(self, action):
class InfiniteHorizonEnv (line 46) | class InfiniteHorizonEnv(gym.Env):
method __init__ (line 47) | def __init__(self, n_states=4):
method reset (line 54) | def reset(self, *, seed: int | None = None, options: dict | None = None):
method step (line 61) | def step(self, action):
class CheckGAECallback (line 66) | class CheckGAECallback(BaseCallback):
method __init__ (line 67) | def __init__(self):
method _on_rollout_end (line 70) | def _on_rollout_end(self):
method _on_step (line 105) | def _on_step(self):
class CustomPolicy (line 109) | class CustomPolicy(ActorCriticPolicy):
method __init__ (line 112) | def __init__(self, *args, **kwargs):
method forward (line 116) | def forward(self, obs, deterministic=False):
function test_env (line 124) | def test_env(env_cls):
function test_gae_computation (line 133) | def test_gae_computation(model_class, gae_lambda, gamma, num_episodes):
function test_infinite_horizon (line 153) | def test_infinite_horizon(model_class, handle_timeout_termination):
FILE: tests/test_her.py
function test_import_error (line 20) | def test_import_error():
function test_her (line 30) | def test_her(model_class, image_obs_space):
function test_multiprocessing (line 68) | def test_multiprocessing(model_class, image_obs_space):
function test_goal_selection_strategy (line 88) | def test_goal_selection_strategy(goal_selection_strategy):
function test_save_load (line 122) | def test_save_load(tmp_path, model_class, use_sde):
function test_save_load_replay_buffer (line 227) | def test_save_load_replay_buffer(n_envs, tmp_path, recwarn, truncate_las...
function test_full_replay_buffer (line 292) | def test_full_replay_buffer():
function test_truncate_last_trajectory (line 329) | def test_truncate_last_trajectory(n_envs, recwarn, n_steps, handle_timeo...
function test_performance_her (line 435) | def test_performance_her(n_bits):
FILE: tests/test_identity.py
function test_discrete (line 15) | def test_discrete(model_class, env):
function test_continuous (line 34) | def test_continuous(model_class):
FILE: tests/test_logger.py
class LogContent (line 55) | class LogContent:
method __init__ (line 60) | def __init__(self, _format: str, lines: Sequence):
method empty (line 65) | def empty(self):
method __repr__ (line 68) | def __repr__(self):
function read_log (line 73) | def read_log(tmp_path, capsys):
function test_set_logger (line 109) | def test_set_logger(tmp_path):
function test_main (line 154) | def test_main(tmp_path):
function test_make_output (line 192) | def test_make_output(tmp_path, read_log, _format):
function test_make_output_fail (line 208) | def test_make_output_fail(tmp_path):
function test_exclude_keys (line 218) | def test_exclude_keys(tmp_path, read_log, _format):
function test_report_video_to_tensorboard (line 229) | def test_report_video_to_tensorboard(tmp_path, read_log, capsys):
function is_moviepy_installed (line 250) | def is_moviepy_installed():
function test_unsupported_video_format (line 255) | def test_unsupported_video_format(tmp_path, unsupported_format):
function test_log_histogram (line 274) | def test_log_histogram(tmp_path, read_log, histogram):
function test_unsupported_type_histogram (line 299) | def test_unsupported_type_histogram(tmp_path, read_log, histogram):
function test_report_image_to_tensorboard (line 313) | def test_report_image_to_tensorboard(tmp_path, read_log):
function test_unsupported_image_format (line 325) | def test_unsupported_image_format(tmp_path, unsupported_format):
function test_report_figure_to_tensorboard (line 335) | def test_report_figure_to_tensorboard(tmp_path, read_log):
function test_unsupported_figure_format (line 349) | def test_unsupported_figure_format(tmp_path, unsupported_format):
function test_unsupported_hparam (line 362) | def test_unsupported_hparam(tmp_path, unsupported_format):
function test_key_length (line 374) | def test_key_length(tmp_path):
class TimeDelayEnv (line 405) | class TimeDelayEnv(gym.Env):
method __init__ (line 410) | def __init__(self, delay: float = 0.01):
method reset (line 416) | def reset(self, seed=None):
method step (line 419) | def step(self, action):
function test_env (line 426) | def test_env(env_cls):
class InMemoryLogger (line 431) | class InMemoryLogger(Logger):
method __init__ (line 436) | def __init__(self):
method dump (line 439) | def dump(self, step: int = 0) -> None:
function test_fps_logger (line 444) | def test_fps_logger(tmp_path, algo):
function test_fps_no_div_zero (line 469) | def test_fps_no_div_zero(algo):
function test_human_output_same_keys_different_tags (line 481) | def test_human_output_same_keys_different_tags():
function test_ep_buffers_stats_window_size (line 491) | def test_ep_buffers_stats_window_size(algo, stats_window_size):
function test_human_out_custom_text_io (line 501) | def test_human_out_custom_text_io(base_class):
class DummySuccessEnv (line 539) | class DummySuccessEnv(gym.Env):
method __init__ (line 545) | def __init__(self, dummy_successes, ep_steps):
method reset (line 566) | def reset(self, seed=None, options=None):
method step (line 578) | def step(self, action):
function test_rollout_success_rate_onpolicy_algo (line 593) | def test_rollout_success_rate_onpolicy_algo(tmp_path):
FILE: tests/test_monitor.py
function test_monitor (line 20) | def test_monitor(tmp_path):
function test_monitor_load_results (line 70) | def test_monitor_load_results(tmp_path):
FILE: tests/test_n_step_replay.py
function test_run (line 11) | def test_run(model_class):
function create_buffer (line 33) | def create_buffer(buffer_size=10, n_steps=3, gamma=0.99, n_envs=1):
function create_normal_buffer (line 47) | def create_normal_buffer(buffer_size=10, n_envs=1):
function fill_buffer (line 59) | def fill_buffer(buffer, length, done_at=None, truncated_at=None):
function compute_expected_nstep_reward (line 78) | def compute_expected_nstep_reward(gamma, n_steps, stop_idx=None):
function test_nstep_early_termination (line 96) | def test_nstep_early_termination(done_at, n_steps, base_idx):
function test_nstep_early_truncation (line 111) | def test_nstep_early_truncation(truncated_at, n_steps, base_idx):
function test_nstep_no_terminations (line 124) | def test_nstep_no_terminations(n_steps):
function test_match_normal_buffer (line 161) | def test_match_normal_buffer():
FILE: tests/test_predict.py
class SubClassedBox (line 22) | class SubClassedBox(spaces.Box):
method __init__ (line 23) | def __init__(self, *args, **kwargs):
class CustomSubClassedSpaceEnv (line 27) | class CustomSubClassedSpaceEnv(gym.Env):
method __init__ (line 28) | def __init__(self):
method reset (line 33) | def reset(self, seed=None):
method step (line 36) | def step(self, action):
function test_env (line 41) | def test_env(env_cls):
function test_auto_wrap (line 47) | def test_auto_wrap(model_class):
function test_predict (line 62) | def test_predict(model_class, env_id, device):
function test_dqn_epsilon_greedy (line 102) | def test_dqn_epsilon_greedy():
function test_subclassed_space_env (line 113) | def test_subclassed_space_env(model_class):
function test_mixing_gym_vecenv_api (line 121) | def test_mixing_gym_vecenv_api():
FILE: tests/test_preprocessing.py
function test_get_obs_shape_discrete (line 7) | def test_get_obs_shape_discrete():
function test_get_obs_shape_multidiscrete (line 11) | def test_get_obs_shape_multidiscrete():
function test_get_obs_shape_multibinary (line 15) | def test_get_obs_shape_multibinary():
function test_get_obs_shape_multidimensional_multibinary (line 19) | def test_get_obs_shape_multidimensional_multibinary():
function test_get_obs_shape_box (line 23) | def test_get_obs_shape_box():
function test_get_obs_shape_multidimensional_box (line 27) | def test_get_obs_shape_multidimensional_box():
function test_preprocess_obs_discrete (line 31) | def test_preprocess_obs_discrete():
function test_preprocess_obs_multidiscrete (line 37) | def test_preprocess_obs_multidiscrete():
function test_preprocess_obs_multibinary (line 43) | def test_preprocess_obs_multibinary():
function test_preprocess_obs_multidimensional_multibinary (line 49) | def test_preprocess_obs_multidimensional_multibinary():
function test_preprocess_obs_box (line 55) | def test_preprocess_obs_box():
function test_preprocess_obs_multidimensional_box (line 61) | def test_preprocess_obs_multidimensional_box():
FILE: tests/test_run.py
function test_deterministic_pg (line 18) | def test_deterministic_pg(model_class, action_noise):
function test_a2c (line 35) | def test_a2c(env_id):
function test_advantage_normalization (line 42) | def test_advantage_normalization(model_class, normalize_advantage):
function test_ppo (line 49) | def test_ppo(env_id, clip_range_vf):
function test_sac (line 76) | def test_sac(ent_coef):
function test_n_critics (line 91) | def test_n_critics(n_critics):
function test_dqn (line 104) | def test_dqn():
function test_train_freq (line 118) | def test_train_freq(tmp_path, train_freq):
function test_train_freq_fail (line 138) | def test_train_freq_fail(train_freq):
function test_offpolicy_multi_env (line 153) | def test_offpolicy_multi_env(model_class):
function test_warn_dqn_multi_env (line 204) | def test_warn_dqn_multi_env():
function test_ppo_warnings (line 214) | def test_ppo_warnings():
FILE: tests/test_save_load.py
function select_env (line 28) | def select_env(model_class: BaseAlgorithm) -> gym.Env:
function test_save_load (line 39) | def test_save_load(tmp_path, model_class):
function test_set_env (line 181) | def test_set_env(tmp_path, model_class):
function test_exclude_include_saved_params (line 256) | def test_exclude_include_saved_params(tmp_path, model_class):
function test_save_load_pytorch_var (line 296) | def test_save_load_pytorch_var(tmp_path):
function test_save_load_env_cnn (line 334) | def test_save_load_env_cnn(tmp_path, model_class):
function test_save_load_replay_buffer (line 363) | def test_save_load_replay_buffer(tmp_path, model_class):
function test_warn_buffer (line 400) | def test_warn_buffer(recwarn, model_class, optimize_memory_usage):
function test_save_load_policy (line 443) | def test_save_load_policy(tmp_path, model_class, policy_str, use_sde):
function test_save_load_q_net (line 548) | def test_save_load_q_net(tmp_path, model_class, policy_str):
function test_open_file_str_pathlib (line 623) | def test_open_file_str_pathlib(tmp_path, pathtype):
function test_open_file (line 671) | def test_open_file(tmp_path):
function test_save_load_large_model (line 706) | def test_save_load_large_model(tmp_path):
function test_load_invalid_object (line 726) | def test_load_invalid_object(tmp_path):
function test_dqn_target_update_interval (line 756) | def test_dqn_target_update_interval(tmp_path):
function test_no_resource_warning (line 769) | def test_no_resource_warning(tmp_path):
function test_cast_lr_schedule (line 802) | def test_cast_lr_schedule(tmp_path):
function test_save_load_net_arch_none (line 816) | def test_save_load_net_arch_none(tmp_path):
function test_save_load_no_target_params (line 828) | def test_save_load_no_target_params(tmp_path):
function test_save_load_backward_compatible (line 841) | def test_save_load_backward_compatible(tmp_path, model_class):
function test_save_load_clip_range_portable (line 864) | def test_save_load_clip_range_portable(tmp_path, model_class):
FILE: tests/test_sde.py
function test_state_dependent_exploration_grad (line 10) | def test_state_dependent_exploration_grad():
function test_sde_check (line 59) | def test_sde_check():
function test_only_sde_squashed (line 64) | def test_only_sde_squashed():
function test_state_dependent_noise (line 72) | def test_state_dependent_noise(model_class, use_expln, squash_output):
class StoreActionEnvWrapper (line 109) | class StoreActionEnvWrapper(gym.Wrapper):
method __init__ (line 114) | def __init__(self, env):
method step (line 119) | def step(self, action):
FILE: tests/test_spaces.py
class DummyEnv (line 19) | class DummyEnv(gym.Env):
method step (line 23) | def step(self, action):
method reset (line 26) | def reset(self, *, seed: int | None = None, options: dict | None = None):
class DummyMultidimensionalAction (line 32) | class DummyMultidimensionalAction(DummyEnv):
method __init__ (line 33) | def __init__(self):
class DummyMultiBinary (line 40) | class DummyMultiBinary(DummyEnv):
method __init__ (line 41) | def __init__(self, n):
class DummyMultiDiscreteSpace (line 48) | class DummyMultiDiscreteSpace(DummyEnv):
method __init__ (line 49) | def __init__(self, nvec):
function test_env (line 65) | def test_env(env):
function test_identity_spaces (line 72) | def test_identity_spaces(model_class, env):
function test_action_spaces (line 91) | def test_action_spaces(model_class, env):
function test_sde_multi_dim (line 112) | def test_sde_multi_dim():
function test_discrete_obs_space (line 125) | def test_discrete_obs_space(model_class, env):
function test_float64_action_space (line 152) | def test_float64_action_space(model_class, obs_space, action_space):
function test_multidim_binary_not_supported (line 172) | def test_multidim_binary_not_supported():
FILE: tests/test_tensorboard.py
class HParamCallback (line 20) | class HParamCallback(BaseCallback):
method _on_training_start (line 25) | def _on_training_start(self) -> None:
method _on_step (line 44) | def _on_step(self) -> bool:
function test_tensorboard (line 49) | def test_tensorboard(tmp_path, model_name):
function test_escape_log_name (line 76) | def test_escape_log_name(tmp_path):
FILE: tests/test_train_eval_mode.py
class FlattenBatchNormDropoutExtractor (line 20) | class FlattenBatchNormDropoutExtractor(BaseFeaturesExtractor):
method __init__ (line 28) | def __init__(self, observation_space: gym.Space):
method forward (line 37) | def forward(self, observations: th.Tensor) -> th.Tensor:
function clone_batch_norm_stats (line 44) | def clone_batch_norm_stats(batch_norm: nn.BatchNorm1d) -> (th.Tensor, th...
function clone_dqn_batch_norm_stats (line 54) | def clone_dqn_batch_norm_stats(model: DQN) -> (th.Tensor, th.Tensor, th....
function clone_td3_batch_norm_stats (line 70) | def clone_td3_batch_norm_stats(
function clone_sac_batch_norm_stats (line 103) | def clone_sac_batch_norm_stats(
function clone_on_policy_batch_norm (line 124) | def clone_on_policy_batch_norm(model: A2C | PPO) -> (th.Tensor, th.Tensor):
function test_dqn_train_with_batch_norm (line 137) | def test_dqn_train_with_batch_norm():
function test_td3_train_with_batch_norm (line 178) | def test_td3_train_with_batch_norm():
function test_sac_train_with_batch_norm (line 227) | def test_sac_train_with_batch_norm():
function test_a2c_ppo_train_with_batch_norm (line 271) | def test_a2c_ppo_train_with_batch_norm(model_class, env_id):
function test_offpolicy_collect_rollout_batch_norm (line 290) | def test_offpolicy_collect_rollout_batch_norm(model_class):
function test_a2c_ppo_collect_rollouts_with_batch_norm (line 322) | def test_a2c_ppo_collect_rollouts_with_batch_norm(model_class, env_id):
function test_predict_with_dropout_batch_norm (line 346) | def test_predict_with_dropout_batch_norm(model_class, env_id):
FILE: tests/test_utils.py
function test_make_vec_env (line 41) | def test_make_vec_env(env_id, n_envs, vec_env_cls, wrapper_class):
function test_make_vec_env_func_checker (line 58) | def test_make_vec_env_func_checker():
function test_make_atari_env (line 76) | def test_make_atari_env(
function test_vec_env_kwargs (line 119) | def test_vec_env_kwargs():
function test_vec_env_wrapper_kwargs (line 124) | def test_vec_env_wrapper_kwargs():
function test_vec_env_monitor_kwargs (line 129) | def test_vec_env_monitor_kwargs():
function test_env_auto_monitor_wrap (line 148) | def test_env_auto_monitor_wrap():
function test_custom_vec_env (line 161) | def test_custom_vec_env(tmp_path):
function test_evaluate_policy (line 189) | def test_evaluate_policy(direct_policy):
class ZeroRewardWrapper (line 237) | class ZeroRewardWrapper(gym.RewardWrapper):
method reward (line 238) | def reward(self, reward):
class AlwaysDoneWrapper (line 242) | class AlwaysDoneWrapper(gym.Wrapper):
method __init__ (line 245) | def __init__(self, env):
method step (line 250) | def step(self, action):
method reset (line 256) | def reset(self, **kwargs):
function test_evaluate_vector_env (line 266) | def test_evaluate_vector_env(n_envs):
function test_evaluate_policy_monitors (line 289) | def test_evaluate_policy_monitors(vec_env_class):
function test_vec_noise (line 354) | def test_vec_noise():
function test_get_parameters_by_name (line 387) | def test_get_parameters_by_name():
function test_polyak (line 403) | def test_polyak():
function test_zip_strict (line 416) | def test_zip_strict():
function test_is_wrapped (line 434) | def test_is_wrapped():
function test_get_system_info (line 447) | def test_get_system_info():
function test_is_vectorized_observation (line 457) | def test_is_vectorized_observation():
function test_policy_is_vectorized_obs (line 533) | def test_policy_is_vectorized_obs():
function test_check_shape_equal (line 582) | def test_check_shape_equal():
function test_deprecated_schedules (line 602) | def test_deprecated_schedules():
FILE: tests/test_vec_check_nan.py
class NanAndInfEnv (line 9) | class NanAndInfEnv(gym.Env):
method __init__ (line 14) | def __init__(self):
method step (line 20) | def step(action):
method reset (line 30) | def reset(seed=None):
method render (line 33) | def render(self):
function test_check_nan (line 37) | def test_check_nan():
FILE: tests/test_vec_envs.py
class CustomGymEnv (line 30) | class CustomGymEnv(gym.Env):
method __init__ (line 31) | def __init__(self, space, render_mode: str = "rgb_array"):
method reset (line 42) | def reset(self, *, seed: int | None = None, options: dict | None = None):
method step (line 50) | def step(self, action):
method _choose_next_state (line 58) | def _choose_next_state(self):
method render (line 61) | def render(self):
method seed (line 65) | def seed(self, seed=None):
method custom_method (line 71) | def custom_method(dim_0=1, dim_1=1):
function test_vecenv_func_checker (line 83) | def test_vecenv_func_checker():
function test_vecenv_custom_calls (line 95) | def test_vecenv_custom_calls(vec_env_class, vec_env_wrapper):
class StepEnv (line 218) | class StepEnv(gym.Env):
method __init__ (line 219) | def __init__(self, max_steps):
method reset (line 227) | def reset(self, *, seed: int | None = None, options: dict | None = None):
method step (line 231) | def step(self, action):
function test_vecenv_terminal_obs (line 241) | def test_vecenv_terminal_obs(vec_env_class, vec_env_wrapper):
function check_vecenv_spaces (line 294) | def check_vecenv_spaces(vec_env_class, space, obs_assert):
function check_vecenv_obs (line 312) | def check_vecenv_obs(obs, space):
function test_vecenv_single_space (line 321) | def test_vecenv_single_space(vec_env_class, space):
class _UnorderedDictSpace (line 328) | class _UnorderedDictSpace(spaces.Dict):
method sample (line 331) | def sample(self):
function test_vecenv_dict_spaces (line 336) | def test_vecenv_dict_spaces(vec_env_class):
function test_vecenv_tuple_spaces (line 354) | def test_vecenv_tuple_spaces(vec_env_class):
function test_subproc_start_method (line 367) | def test_subproc_start_method():
class CustomWrapperA (line 388) | class CustomWrapperA(VecNormalize):
method __init__ (line 389) | def __init__(self, venv):
class CustomWrapperB (line 394) | class CustomWrapperB(VecNormalize):
method __init__ (line 395) | def __init__(self, venv):
method func_b (line 399) | def func_b(self):
method name_test (line 402) | def name_test(self):
class CustomWrapperBB (line 406) | class CustomWrapperBB(CustomWrapperB):
method __init__ (line 407) | def __init__(self, venv):
function test_vecenv_wrapper_getattr (line 412) | def test_vecenv_wrapper_getattr():
function test_framestack_vecenv (line 432) | def test_framestack_vecenv():
function test_vec_env_is_wrapped (line 506) | def test_vec_env_is_wrapped():
function test_vec_deterministic (line 531) | def test_vec_deterministic(vec_env_class):
function test_vec_seeding (line 561) | def test_vec_seeding(vec_env_class):
function test_render (line 592) | def test_render(vec_env_class):
function test_video_recorder (line 661) | def test_video_recorder(tmp_path):
FILE: tests/test_vec_extract_dict_obs.py
class DictObsVecEnv (line 8) | class DictObsVecEnv(VecEnv):
method __init__ (line 13) | def __init__(self):
method step_async (line 21) | def step_async(self, actions):
method step_wait (line 24) | def step_wait(self):
method reset (line 41) | def reset(self):
method render (line 45) | def render(self, mode=""):
method get_attr (line 48) | def get_attr(self, attr_name, indices=None):
method close (line 52) | def close(self):
method env_is_wrapped (line 55) | def env_is_wrapped(self, wrapper_class, indices=None):
method env_method (line 59) | def env_method(self):
method set_attr (line 62) | def set_attr(self, attr_name, value, indices=None) -> None:
function test_extract_dict_obs (line 66) | def test_extract_dict_obs():
function test_vec_with_ppo (line 84) | def test_vec_with_ppo():
FILE: tests/test_vec_monitor.py
function test_vec_monitor (line 18) | def test_vec_monitor(tmp_path):
function test_vec_monitor_info_keywords (line 51) | def test_vec_monitor_info_keywords(tmp_path):
function test_vec_monitor_load_results (line 82) | def test_vec_monitor_load_results(tmp_path):
function test_vec_monitor_ppo (line 132) | def test_vec_monitor_ppo(recwarn):
function test_vec_monitor_warn (line 148) | def test_vec_monitor_warn():
FILE: tests/test_vec_normalize.py
class DummyRewardEnv (line 24) | class DummyRewardEnv(gym.Env):
method __init__ (line 27) | def __init__(self, return_reward_idx=0):
method step (line 34) | def step(self, action):
method reset (line 42) | def reset(self, *, seed: int | None = None, options: dict | None = None):
class DummyDictEnv (line 49) | class DummyDictEnv(gym.Env):
method __init__ (line 54) | def __init__(self):
method reset (line 65) | def reset(self, *, seed: int | None = None, options: dict | None = None):
method step (line 70) | def step(self, action):
method compute_reward (line 76) | def compute_reward(self, achieved_goal: np.ndarray, desired_goal: np.n...
class DummyMixedDictEnv (line 81) | class DummyMixedDictEnv(gym.Env):
method __init__ (line 86) | def __init__(self):
method reset (line 97) | def reset(self, *, seed: int | None = None, options: dict | None = None):
method step (line 102) | def step(self, action):
function allclose (line 108) | def allclose(obs_1, obs_2):
function make_env (line 122) | def make_env():
function make_env_render (line 126) | def make_env_render():
function make_dict_env (line 130) | def make_dict_env():
function make_image_env (line 134) | def make_image_env():
function check_rms_equal (line 138) | def check_rms_equal(rmsa, rmsb):
function check_vec_norm_equal (line 150) | def check_vec_norm_equal(norma, normb):
function _make_warmstart (line 168) | def _make_warmstart(env_fn, **kwargs):
function _make_warmstart_cliffwalking (line 181) | def _make_warmstart_cliffwalking(**kwargs):
function _make_warmstart_cartpole (line 190) | def _make_warmstart_cartpole():
function _make_warmstart_dict_env (line 195) | def _make_warmstart_dict_env(**kwargs):
function test_runningmeanstd (line 200) | def test_runningmeanstd():
function test_combining_stats (line 218) | def test_combining_stats():
function test_obs_rms_vec_normalize (line 250) | def test_obs_rms_vec_normalize():
function test_vec_env (line 269) | def test_vec_env(tmp_path, make_gym_env):
function test_get_original (line 306) | def test_get_original():
function test_get_original_dict (line 325) | def test_get_original_dict():
function test_normalize_external (line 345) | def test_normalize_external():
function test_normalize_dict_selected_keys (line 355) | def test_normalize_dict_selected_keys():
function test_her_normalization (line 370) | def test_her_normalization():
function test_offpolicy_normalization (line 400) | def test_offpolicy_normalization(model_class):
function test_sync_vec_normalize (line 421) | def test_sync_vec_normalize(make_env):
function test_discrete_obs (line 479) | def test_discrete_obs():
function test_non_dict_obs_keys (line 487) | def test_non_dict_obs_keys():
FILE: tests/test_vec_stacked_obs.py
function test_compute_stacking_box (line 12) | def test_compute_stacking_box():
function test_compute_stacking_multidim_box (line 21) | def test_compute_stacking_multidim_box():
function test_compute_stacking_multidim_box_channel_first (line 30) | def test_compute_stacking_multidim_box_channel_first():
function test_compute_stacking_image_channel_first (line 41) | def test_compute_stacking_image_channel_first():
function test_compute_stacking_image_channel_last (line 51) | def test_compute_stacking_image_channel_last():
function test_compute_stacking_image_channel_first_stack_last (line 61) | def test_compute_stacking_image_channel_first_stack_last():
function test_compute_stacking_image_channel_last_stack_first (line 73) | def test_compute_stacking_image_channel_last_stack_first():
function test_reset_update_box (line 85) | def test_reset_update_box():
function test_reset_update_multidim_box (line 106) | def test_reset_update_multidim_box():
function test_reset_update_multidim_box_channel_first (line 127) | def test_reset_update_multidim_box_channel_first():
function test_reset_update_image_channel_first (line 146) | def test_reset_update_image_channel_first():
function test_reset_update_image_channel_last (line 165) | def test_reset_update_image_channel_last():
function test_reset_update_image_channel_first_stack_last (line 186) | def test_reset_update_image_channel_first_stack_last():
function test_reset_update_image_channel_last_stack_first (line 207) | def test_reset_update_image_channel_last_stack_first():
function test_reset_update_dict (line 226) | def test_reset_update_dict():
function test_episode_termination_box (line 271) | def test_episode_termination_box():
function test_episode_termination_dict (line 292) | def test_episode_termination_dict():
Condensed preview: 170 files, each showing path, character count, and a content snippet.
[
{
"path": ".github/ISSUE_TEMPLATE/bug_report.yml",
"chars": 3192,
"preview": "name: \"\\U0001F41B Bug Report\"\ndescription: If you encounter an unexpected behavior, software crash, or other bug.\ntitle:"
},
{
"path": ".github/ISSUE_TEMPLATE/custom_env.yml",
"chars": 4472,
"preview": "name: \"\\U0001F916 Custom Gym Environment Issue\"\ndescription: If your problem involves a custom gym environment.\nlabels: "
},
{
"path": ".github/ISSUE_TEMPLATE/documentation.yml",
"chars": 1156,
"preview": "name: \"\\U0001F4DA Documentation\"\ndescription: If you want to improve the documentation by reporting errors, inconsistenc"
},
{
"path": ".github/ISSUE_TEMPLATE/feature_request.yml",
"chars": 1926,
"preview": "name: \"\\U0001F680 Feature Request\"\ndescription: If you have an idea for a new feature or an improvement.\ntitle: \"[Featur"
},
{
"path": ".github/ISSUE_TEMPLATE/question.yml",
"chars": 1914,
"preview": "name: \"❓ Question\"\ndescription: If you have a general question about Stable-Baselines3.\ntitle: \"[Question] question titl"
},
{
"path": ".github/PULL_REQUEST_TEMPLATE.md",
"chars": 2359,
"preview": "<!--- Provide a general summary of your changes in the Title above -->\n\n## Description\n<!--- Describe your changes in de"
},
{
"path": ".github/workflows/ci.yml",
"chars": 2272,
"preview": "# This workflow will install Python dependencies, run tests and lint with a variety of Python versions\n# For more inform"
},
{
"path": ".gitignore",
"chars": 424,
"preview": "*.swp\n*.pyc\n*.pkl\n*.py~\n*.bak\n.pytest_cache\n.mypy_cache\n.DS_Store\n.idea\n.vscode\n.coverage\n.coverage.*\n__pycache__/\n_buil"
},
{
"path": ".readthedocs.yml",
"chars": 460,
"preview": "# Read the Docs configuration file\n# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details\n\n# Requir"
},
{
"path": "CITATION.bib",
"chars": 418,
"preview": "@article{stable-baselines3,\n author = {Antonin Raffin and Ashley Hill and Adam Gleave and Anssi Kanervisto and Maximil"
},
{
"path": "CODE_OF_CONDUCT.md",
"chars": 5239,
"preview": "# Contributor Covenant Code of Conduct\n\n## Our Pledge\n\nWe as members, contributors, and leaders pledge to make participa"
},
{
"path": "CONTRIBUTING.md",
"chars": 4277,
"preview": "## Contributing to Stable-Baselines3\n\n**Important: When submitting issues or pull requests, the use of LLM or code assis"
},
{
"path": "Dockerfile",
"chars": 1043,
"preview": "ARG PARENT_IMAGE=mambaorg/micromamba:2.0-ubuntu24.04\nFROM $PARENT_IMAGE\nARG PYTORCH_DEPS=https://download.pytorch.org/wh"
},
{
"path": "LICENSE",
"chars": 1075,
"preview": "The MIT License\n\nCopyright (c) 2019 Antonin Raffin\n\nPermission is hereby granted, free of charge, to any person obtainin"
},
{
"path": "Makefile",
"chars": 1468,
"preview": "SHELL=/bin/bash\nLINT_PATHS=stable_baselines3/ tests/ docs/conf.py setup.py\n\npytest:\n\t./scripts/run_tests.sh\n\nmypy:\n\tmypy"
},
{
"path": "NOTICE",
"chars": 1338,
"preview": "Large portion of the code of Stable-Baselines3 (in `common/`) were ported from Stable-Baselines, a fork of OpenAI Baseli"
},
{
"path": "README.md",
"chars": 15661,
"preview": "<!-- [](https://gitlab.com/ar"
},
{
"path": "docs/Makefile",
"chars": 688,
"preview": "# Minimal makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line.\n# For debug: SPHINXO"
},
{
"path": "docs/README.md",
"chars": 388,
"preview": "# Stable Baselines3 Documentation\n\nThis folder contains documentation for the RL baselines.\n\n\n### Build the Documentatio"
},
{
"path": "docs/_static/css/baselines_theme.css",
"chars": 1186,
"preview": "/* Main colors adapted from pytorch doc */\n:root{\n --main-bg-color: #343A40;\n --link-color: #FD7E14;\n}\n\n/* Header fon"
},
{
"path": "docs/common/atari_wrappers.md",
"chars": 122,
"preview": "(atari-wrapper)=\n\n# Atari Wrappers\n\n```{eval-rst}\n.. automodule:: stable_baselines3.common.atari_wrappers\n :members:\n``"
},
{
"path": "docs/common/distributions.md",
"chars": 1056,
"preview": "(distributions)=\n\n# Probability Distributions\n\nProbability distributions used for the different action spaces:\n\n- `Categ"
},
{
"path": "docs/common/env_checker.md",
"chars": 126,
"preview": "(env-checker)=\n\n# Gym Environment Checker\n\n```{eval-rst}\n.. automodule:: stable_baselines3.common.env_checker\n :members"
},
{
"path": "docs/common/env_util.md",
"chars": 115,
"preview": "(env-util)=\n\n# Environments Utils\n\n```{eval-rst}\n.. automodule:: stable_baselines3.common.env_util\n :members:\n```\n"
},
{
"path": "docs/common/envs.md",
"chars": 320,
"preview": "(envs)=\n\n```{eval-rst}\n.. automodule:: stable_baselines3.common.envs\n\n\n```\n\n# Custom Environments\n\nThose environments we"
},
{
"path": "docs/common/evaluation.md",
"chars": 112,
"preview": "(eval)=\n\n# Evaluation Helper\n\n```{eval-rst}\n.. automodule:: stable_baselines3.common.evaluation\n :members:\n```\n"
},
{
"path": "docs/common/logger.md",
"chars": 5250,
"preview": "(logger)=\n\n# Logger\n\nTo overwrite the default logger, you can pass one to the algorithm.\nAvailable formats are `[\"stdout"
},
{
"path": "docs/common/monitor.md",
"chars": 110,
"preview": "(monitor)=\n\n# Monitor Wrapper\n\n```{eval-rst}\n.. automodule:: stable_baselines3.common.monitor\n :members:\n```\n"
},
{
"path": "docs/common/noise.md",
"chars": 103,
"preview": "(noise)=\n\n# Action Noise\n\n```{eval-rst}\n.. automodule:: stable_baselines3.common.noise\n :members:\n```\n"
},
{
"path": "docs/common/utils.md",
"chars": 96,
"preview": "(utils)=\n\n# Utils\n\n```{eval-rst}\n.. automodule:: stable_baselines3.common.utils\n :members:\n```\n"
},
{
"path": "docs/conda_env.yml",
"chars": 377,
"preview": "name: root\nchannels:\n - pytorch\n - conda-forge\ndependencies:\n - cpuonly=1.0=0\n - pip=24.2\n - python=3.11\n - pytorc"
},
{
"path": "docs/conf.py",
"chars": 6889,
"preview": "#\n# Configuration file for the Sphinx documentation builder.\n#\n# This file does only contain a selection of the most com"
},
{
"path": "docs/guide/algos.md",
"chars": 5302,
"preview": "# RL Algorithms\n\nThis table displays the RL algorithms that are implemented in the Stable Baselines3 project,\nalong with"
},
{
"path": "docs/guide/callbacks.md",
"chars": 14700,
"preview": "(callbacks)=\n\n# Callbacks\n\nA callback is a set of functions that will be called at given stages of the training procedur"
},
{
"path": "docs/guide/checking_nan.md",
"chars": 5199,
"preview": "# Dealing with NaNs and infs\n\nDuring the training of a model on a given environment, it is possible that the RL model be"
},
{
"path": "docs/guide/custom_env.md",
"chars": 5810,
"preview": "(custom-env)=\n\n# Using Custom Environments\n\nTo use the RL baselines with custom environments, they just need to follow t"
},
{
"path": "docs/guide/custom_policy.md",
"chars": 15059,
"preview": "(custom-policy)=\n\n# Policy Networks\n\nStable Baselines3 provides policy networks for images (CnnPolicies),\nother type of "
},
{
"path": "docs/guide/developer.md",
"chars": 4252,
"preview": "(developer)=\n\n# Developer Guide\n\nThis guide is meant for those who want to understand the internals and the design choic"
},
{
"path": "docs/guide/examples.md",
"chars": 30335,
"preview": "---\nmyst:\n substitutions:\n colab: |-\n ```{image} ../_static/img/colab.svg\n ```\n---\n\n(examples)=\n\n# Example"
},
{
"path": "docs/guide/export.md",
"chars": 15844,
"preview": "(export)=\n\n# Exporting models\n\nAfter training an agent, you may want to deploy/use it in another language\nor framework, "
},
{
"path": "docs/guide/imitation.md",
"chars": 835,
"preview": "(imitation)=\n\n# Imitation Learning\n\nThe [imitation](https://github.com/HumanCompatibleAI/imitation) library implements\ni"
},
{
"path": "docs/guide/install.md",
"chars": 4783,
"preview": "(install)=\n\n# Installation\n\n## Prerequisites\n\nStable-Baselines3 requires python 3.10+ and PyTorch >= 2.3\n\n### Windows\n\nW"
},
{
"path": "docs/guide/integrations.md",
"chars": 8005,
"preview": "(integrations)=\n\n# Integrations\n\n## Weights & Biases\n\nWeights & Biases provides a callback for experiment tracking that "
},
{
"path": "docs/guide/migration.md",
"chars": 8895,
"preview": "(migration)=\n\n# Migrating from Stable-Baselines\n\nThis is a guide to migrate from Stable-Baselines (SB2) to Stable-Baseli"
},
{
"path": "docs/guide/plotting.md",
"chars": 5974,
"preview": "(plotting)=\n\n# Plotting\n\nStable Baselines3 provides utilities for plotting training results, allowing you to monitor and"
},
{
"path": "docs/guide/quickstart.md",
"chars": 1372,
"preview": "(quickstart)=\n\n# Getting Started\n\n:::{note}\nStable-Baselines3 (SB3) uses [vectorized environments (VecEnv)](vec_envs.md)"
},
{
"path": "docs/guide/rl.md",
"chars": 1008,
"preview": "(rl)=\n\n# Reinforcement Learning Resources\n\nStable-Baselines3 assumes that you already understand the basic concepts of R"
},
{
"path": "docs/guide/rl_tips.md",
"chars": 15163,
"preview": "(rl-tips)=\n\n# Reinforcement Learning Tips and Tricks\n\nThe aim of this section is to help you run reinforcement learning "
},
{
"path": "docs/guide/rl_zoo.md",
"chars": 3179,
"preview": "(rl-zoo)=\n\n# RL Baselines3 Zoo\n\n[RL Baselines3 Zoo](https://github.com/DLR-RM/rl-baselines3-zoo) is a training framework"
},
{
"path": "docs/guide/save_format.md",
"chars": 2865,
"preview": "(save-format)=\n\n# On saving and loading\n\nStable Baselines3 (SB3) stores both neural network parameters and algorithm-rel"
},
{
"path": "docs/guide/sb3_contrib.md",
"chars": 3239,
"preview": "(sb3-contrib)=\n\n# SB3 Contrib\n\nWe implement experimental features in a separate contrib repository:\n[SB3-Contrib]\n\nThis "
},
{
"path": "docs/guide/sbx.md",
"chars": 2101,
"preview": "(sbx)=\n\n# Stable Baselines Jax (SBX)\n\n[Stable Baselines Jax (SBX)](https://github.com/araffin/sbx) is a proof of concept"
},
{
"path": "docs/guide/tensorboard.md",
"chars": 12675,
"preview": "(tensorboard)=\n\n# Tensorboard Integration\n\n## Basic Usage\n\nTo use Tensorboard with stable baselines3, you simply need to"
},
{
"path": "docs/guide/vec_envs.md",
"chars": 12762,
"preview": "(vec-env)=\n\n```{eval-rst}\n.. automodule:: stable_baselines3.common.vec_env\n```\n\n# Vectorized Environments\n\nVectorized En"
},
{
"path": "docs/index.rst",
"chars": 3626,
"preview": ".. Stable Baselines3 documentation master file, created by\n sphinx-quickstart on Thu Sep 26 11:06:54 2019.\n You can "
},
{
"path": "docs/make.bat",
"chars": 819,
"preview": "@ECHO OFF\r\n\r\npushd %~dp0\r\n\r\nREM Command file for Sphinx documentation\r\n\r\nif \"%SPHINXBUILD%\" == \"\" (\r\n\tset SPHINXBUILD=sp"
},
{
"path": "docs/misc/changelog.md",
"chars": 82040,
"preview": "(changelog)=\n\n# Changelog\n\n## Release 2.8.0a3 (WIP)\n\n### Breaking Changes:\n\n- Removed support for Python 3.9, please upg"
},
{
"path": "docs/misc/projects.md",
"chars": 14873,
"preview": "(projects)=\n\n# Projects\n\nThis is a list of projects using stable-baselines3.\nPlease tell us, if you want your project to"
},
{
"path": "docs/modules/a2c.md",
"chars": 6079,
"preview": "(a2c)=\n\n```{eval-rst}\n.. automodule:: stable_baselines3.a2c\n\n```\n\n# A2C\n\nA synchronous, deterministic variant of [Asynch"
},
{
"path": "docs/modules/base.md",
"chars": 674,
"preview": "(base-algo)=\n\n```{eval-rst}\n.. automodule:: stable_baselines3.common.base_class\n\n```\n\n# Base RL Class\n\nCommon interface "
},
{
"path": "docs/modules/ddpg.md",
"chars": 4215,
"preview": "(ddpg)=\n\n```{eval-rst}\n.. automodule:: stable_baselines3.ddpg\n\n```\n\n# DDPG\n\n[Deep Deterministic Policy Gradient (DDPG)]("
},
{
"path": "docs/modules/dqn.md",
"chars": 3245,
"preview": "(dqn)=\n\n```{eval-rst}\n.. automodule:: stable_baselines3.dqn\n\n```\n\n# DQN\n\n[Deep Q Network (DQN)](https://arxiv.org/abs/13"
},
{
"path": "docs/modules/her.md",
"chars": 4268,
"preview": "(her)=\n\n```{eval-rst}\n.. automodule:: stable_baselines3.her\n\n```\n\n# HER\n\n[Hindsight Experience Replay (HER)](https://arx"
},
{
"path": "docs/modules/ppo.md",
"chars": 6907,
"preview": "(ppo2)=\n\n```{eval-rst}\n.. automodule:: stable_baselines3.ppo\n```\n\n# PPO\n\nThe [Proximal Policy Optimization](https://arxi"
},
{
"path": "docs/modules/sac.md",
"chars": 5729,
"preview": "(sac)=\n\n```{eval-rst}\n.. automodule:: stable_baselines3.sac\n\n```\n\n# SAC\n\n[Soft Actor Critic (SAC)](https://spinningup.op"
},
{
"path": "docs/modules/td3.md",
"chars": 4415,
"preview": "(td3)=\n\n```{eval-rst}\n.. automodule:: stable_baselines3.td3\n\n```\n\n# TD3\n\n[Twin Delayed DDPG (TD3)](https://spinningup.op"
},
{
"path": "docs/spelling_wordlist.txt",
"chars": 1029,
"preview": "py\nenv\natari\nargparse\nArgparse\nTensorFlow\nfeedforward\nenvs\nVecEnv\npretrain\npetrained\ntf\nth\nnn\nnp\nstr\nmujoco\ncpu\nndarray\n"
},
{
"path": "pyproject.toml",
"chars": 1984,
"preview": "[tool.ruff]\n# Same as Black.\nline-length = 127\n# Assume Python 3.10\ntarget-version = \"py310\"\n\n[tool.ruff.lint]\n# See htt"
},
{
"path": "scripts/build_docker.sh",
"chars": 799,
"preview": "#!/bin/bash\n\nCPU_PARENT=mambaorg/micromamba:2.0-ubuntu24.04\nGPU_PARENT=mambaorg/micromamba:2.0-cuda12.6.3-ubuntu24.04\n\nT"
},
{
"path": "scripts/run_docker_cpu.sh",
"chars": 359,
"preview": "#!/bin/bash\n# Launch an experiment using the docker cpu image\n\ncmd_line=\"$@\"\n\necho \"Executing in the docker (cpu image):"
},
{
"path": "scripts/run_docker_gpu.sh",
"chars": 419,
"preview": "#!/bin/bash\n# Launch an experiment using the docker gpu image\ncmd_line=\"$@\"\necho \"Executing in the docker (gpu image):\"\n"
},
{
"path": "scripts/run_tests.sh",
"chars": 133,
"preview": "#!/bin/bash\npython3 -m pytest --cov-config .coveragerc --cov-report html --cov-report term --cov=. -v --color=yes -m \"no"
},
{
"path": "setup.py",
"chars": 5455,
"preview": "import os\n\nfrom setuptools import find_packages, setup\n\nwith open(os.path.join(\"stable_baselines3\", \"version.txt\")) as f"
},
{
"path": "stable_baselines3/__init__.py",
"chars": 939,
"preview": "import os\n\nfrom stable_baselines3.a2c import A2C\nfrom stable_baselines3.common.utils import get_system_info\nfrom stable_"
},
{
"path": "stable_baselines3/a2c/__init__.py",
"chars": 189,
"preview": "from stable_baselines3.a2c.a2c import A2C\nfrom stable_baselines3.a2c.policies import CnnPolicy, MlpPolicy, MultiInputPol"
},
{
"path": "stable_baselines3/a2c/a2c.py",
"chars": 9164,
"preview": "from typing import Any, ClassVar, TypeVar\n\nimport torch as th\nfrom gymnasium import spaces\nfrom torch.nn import function"
},
{
"path": "stable_baselines3/a2c/policies.py",
"chars": 301,
"preview": "# This file is here just to define MlpPolicy/CnnPolicy\n# that work for A2C\nfrom stable_baselines3.common.policies import"
},
{
"path": "stable_baselines3/common/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "stable_baselines3/common/atari_wrappers.py",
"chars": 11770,
"preview": "from typing import SupportsFloat\n\nimport gymnasium as gym\nimport numpy as np\nfrom gymnasium import spaces\n\nfrom stable_b"
},
{
"path": "stable_baselines3/common/base_class.py",
"chars": 38680,
"preview": "\"\"\"Abstract base classes for RL algorithms.\"\"\"\n\nimport io\nimport pathlib\nimport time\nimport warnings\nfrom abc import ABC"
},
{
"path": "stable_baselines3/common/buffers.py",
"chars": 40735,
"preview": "import warnings\nfrom abc import ABC, abstractmethod\nfrom collections.abc import Generator\nfrom typing import Any\n\nimport"
},
{
"path": "stable_baselines3/common/callbacks.py",
"chars": 27529,
"preview": "import os\nimport warnings\nfrom abc import ABC, abstractmethod\nfrom collections.abc import Callable\nfrom typing import TY"
},
{
"path": "stable_baselines3/common/distributions.py",
"chars": 28170,
"preview": "\"\"\"Probability distributions.\"\"\"\n\nfrom abc import ABC, abstractmethod\nfrom typing import Any, Optional, TypeVar\n\nimport "
},
{
"path": "stable_baselines3/common/env_checker.py",
"chars": 24677,
"preview": "import warnings\nfrom typing import Any\n\nimport gymnasium as gym\nimport numpy as np\nfrom gymnasium import spaces\n\nfrom st"
},
{
"path": "stable_baselines3/common/env_util.py",
"chars": 7932,
"preview": "import os\nfrom collections.abc import Callable\nfrom typing import Any\n\nimport gymnasium as gym\n\nfrom stable_baselines3.c"
},
{
"path": "stable_baselines3/common/envs/__init__.py",
"chars": 533,
"preview": "from stable_baselines3.common.envs.bit_flipping_env import BitFlippingEnv\nfrom stable_baselines3.common.envs.identity_en"
},
{
"path": "stable_baselines3/common/envs/bit_flipping_env.py",
"chars": 9255,
"preview": "from collections import OrderedDict\nfrom typing import Any\n\nimport numpy as np\nfrom gymnasium import Env, spaces\nfrom gy"
},
{
"path": "stable_baselines3/common/envs/identity_env.py",
"chars": 5975,
"preview": "from typing import Any, Generic, TypeVar\n\nimport gymnasium as gym\nimport numpy as np\nfrom gymnasium import spaces\n\nfrom "
},
{
"path": "stable_baselines3/common/envs/multi_input_envs.py",
"chars": 6414,
"preview": "import gymnasium as gym\nimport numpy as np\nfrom gymnasium import spaces\n\nfrom stable_baselines3.common.type_aliases impo"
},
{
"path": "stable_baselines3/common/evaluation.py",
"chars": 6620,
"preview": "import warnings\nfrom collections.abc import Callable\nfrom typing import Any\n\nimport gymnasium as gym\nimport numpy as np\n"
},
{
"path": "stable_baselines3/common/logger.py",
"chars": 24301,
"preview": "import datetime\nimport json\nimport os\nimport sys\nimport tempfile\nimport warnings\nfrom collections import defaultdict\nfro"
},
{
"path": "stable_baselines3/common/monitor.py",
"chars": 9352,
"preview": "__all__ = [\"Monitor\", \"ResultsWriter\", \"get_monitor_files\", \"load_results\"]\n\nimport csv\nimport json\nimport os\nimport tim"
},
{
"path": "stable_baselines3/common/noise.py",
"chars": 5528,
"preview": "import copy\nfrom abc import ABC, abstractmethod\nfrom collections.abc import Iterable\n\nimport numpy as np\nfrom numpy.typi"
},
{
"path": "stable_baselines3/common/off_policy_algorithm.py",
"chars": 27450,
"preview": "import io\nimport pathlib\nimport sys\nimport time\nimport warnings\nfrom copy import deepcopy\nfrom typing import Any, TypeVa"
},
{
"path": "stable_baselines3/common/on_policy_algorithm.py",
"chars": 14576,
"preview": "import sys\nimport time\nimport warnings\nfrom typing import Any, TypeVar\n\nimport numpy as np\nimport torch as th\nfrom gymna"
},
{
"path": "stable_baselines3/common/policies.py",
"chars": 42872,
"preview": "\"\"\"Policies: abstract base class and concrete implementations.\"\"\"\n\nimport collections\nimport copy\nimport warnings\nfrom a"
},
{
"path": "stable_baselines3/common/preprocessing.py",
"chars": 8835,
"preview": "import warnings\n\nimport numpy as np\nimport torch as th\nfrom gymnasium import spaces\nfrom torch.nn import functional as F"
},
{
"path": "stable_baselines3/common/results_plotter.py",
"chars": 4510,
"preview": "from collections.abc import Callable\n\nimport numpy as np\nimport pandas as pd\n\n# import matplotlib\n# matplotlib.use('TkAg"
},
{
"path": "stable_baselines3/common/running_mean_std.py",
"chars": 1987,
"preview": "import numpy as np\n\n\nclass RunningMeanStd:\n def __init__(self, epsilon: float = 1e-4, shape: tuple[int, ...] = ()):\n "
},
{
"path": "stable_baselines3/common/save_util.py",
"chars": 20769,
"preview": "\"\"\"\nSave util taken from stable_baselines\nused to serialize data (class parameters) of model classes\n\"\"\"\n\nimport base64\n"
},
{
"path": "stable_baselines3/common/sb2_compat/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "stable_baselines3/common/sb2_compat/rmsprop_tf_like.py",
"chars": 5656,
"preview": "from collections.abc import Callable, Iterable\nfrom typing import Any\n\nimport torch\nfrom torch.optim import Optimizer\n\n\n"
},
{
"path": "stable_baselines3/common/torch_layers.py",
"chars": 15500,
"preview": "import gymnasium as gym\nimport torch as th\nfrom gymnasium import spaces\nfrom torch import nn\n\nfrom stable_baselines3.com"
},
{
"path": "stable_baselines3/common/type_aliases.py",
"chars": 3312,
"preview": "\"\"\"Common aliases for type hints\"\"\"\n\nfrom collections.abc import Callable\nfrom enum import Enum\nfrom typing import TYPE_"
},
{
"path": "stable_baselines3/common/utils.py",
"chars": 24184,
"preview": "import glob\nimport os\nimport platform\nimport random\nimport re\nimport warnings\nfrom collections import deque\nfrom collect"
},
{
"path": "stable_baselines3/common/vec_env/__init__.py",
"chars": 4360,
"preview": "from copy import deepcopy\nfrom typing import TypeVar\n\nfrom stable_baselines3.common.vec_env.base_vec_env import Cloudpic"
},
{
"path": "stable_baselines3/common/vec_env/base_vec_env.py",
"chars": 19099,
"preview": "import inspect\nimport warnings\nfrom abc import ABC, abstractmethod\nfrom collections.abc import Iterable, Sequence\nfrom c"
},
{
"path": "stable_baselines3/common/vec_env/dummy_vec_env.py",
"chars": 6980,
"preview": "import warnings\nfrom collections import OrderedDict\nfrom collections.abc import Callable, Sequence\nfrom copy import deep"
},
{
"path": "stable_baselines3/common/vec_env/patch_gym.py",
"chars": 3430,
"preview": "import warnings\nfrom inspect import signature\nfrom typing import Union\n\nimport gymnasium\n\ntry:\n import gym\n\n gym_i"
},
{
"path": "stable_baselines3/common/vec_env/stacked_observations.py",
"chars": 8147,
"preview": "import warnings\nfrom collections.abc import Mapping\nfrom typing import Any, Generic, TypeVar\n\nimport numpy as np\nfrom gy"
},
{
"path": "stable_baselines3/common/vec_env/subproc_vec_env.py",
"chars": 11299,
"preview": "import multiprocessing as mp\nimport warnings\nfrom collections.abc import Callable, Sequence\nfrom typing import Any\n\nimpo"
},
{
"path": "stable_baselines3/common/vec_env/util.py",
"chars": 2656,
"preview": "\"\"\"\nHelpers for dealing with vectorized environments.\n\"\"\"\n\nfrom typing import Any\n\nimport numpy as np\nfrom gymnasium imp"
},
{
"path": "stable_baselines3/common/vec_env/vec_check_nan.py",
"chars": 4208,
"preview": "import warnings\n\nimport numpy as np\nfrom gymnasium import spaces\n\nfrom stable_baselines3.common.vec_env.base_vec_env imp"
},
{
"path": "stable_baselines3/common/vec_env/vec_extract_dict_obs.py",
"chars": 1194,
"preview": "import numpy as np\nfrom gymnasium import spaces\n\nfrom stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEn"
},
{
"path": "stable_baselines3/common/vec_env/vec_frame_stack.py",
"chars": 2079,
"preview": "from collections.abc import Mapping\nfrom typing import Any\n\nimport numpy as np\nfrom gymnasium import spaces\n\nfrom stable"
},
{
"path": "stable_baselines3/common/vec_env/vec_monitor.py",
"chars": 3851,
"preview": "import time\nimport warnings\n\nimport numpy as np\n\nfrom stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEn"
},
{
"path": "stable_baselines3/common/vec_env/vec_normalize.py",
"chars": 13594,
"preview": "import inspect\nimport pickle\nfrom copy import deepcopy\nfrom typing import Any\n\nimport numpy as np\nfrom gymnasium import "
},
{
"path": "stable_baselines3/common/vec_env/vec_transpose.py",
"chars": 4503,
"preview": "from copy import deepcopy\n\nimport numpy as np\nfrom gymnasium import spaces\n\nfrom stable_baselines3.common.preprocessing "
},
{
"path": "stable_baselines3/common/vec_env/vec_video_recorder.py",
"chars": 5756,
"preview": "import os\nimport os.path\nfrom collections.abc import Callable\n\nimport numpy as np\nfrom gymnasium import error, logger\n\nf"
},
{
"path": "stable_baselines3/ddpg/__init__.py",
"chars": 194,
"preview": "from stable_baselines3.ddpg.ddpg import DDPG\nfrom stable_baselines3.ddpg.policies import CnnPolicy, MlpPolicy, MultiInpu"
},
{
"path": "stable_baselines3/ddpg/ddpg.py",
"chars": 5835,
"preview": "from typing import Any, TypeVar\n\nimport torch as th\n\nfrom stable_baselines3.common.buffers import ReplayBuffer\nfrom stab"
},
{
"path": "stable_baselines3/ddpg/policies.py",
"chars": 139,
"preview": "# DDPG can be view as a special case of TD3\nfrom stable_baselines3.td3.policies import CnnPolicy, MlpPolicy, MultiInputP"
},
{
"path": "stable_baselines3/dqn/__init__.py",
"chars": 189,
"preview": "from stable_baselines3.dqn.dqn import DQN\nfrom stable_baselines3.dqn.policies import CnnPolicy, MlpPolicy, MultiInputPol"
},
{
"path": "stable_baselines3/dqn/dqn.py",
"chars": 13152,
"preview": "import warnings\nfrom typing import Any, ClassVar, TypeVar\n\nimport numpy as np\nimport torch as th\nfrom gymnasium import s"
},
{
"path": "stable_baselines3/dqn/policies.py",
"chars": 10637,
"preview": "from typing import Any\n\nimport torch as th\nfrom gymnasium import spaces\nfrom torch import nn\n\nfrom stable_baselines3.com"
},
{
"path": "stable_baselines3/her/__init__.py",
"chars": 204,
"preview": "from stable_baselines3.her.goal_selection_strategy import GoalSelectionStrategy\nfrom stable_baselines3.her.her_replay_bu"
},
{
"path": "stable_baselines3/her/goal_selection_strategy.py",
"chars": 649,
"preview": "from enum import Enum\n\n\nclass GoalSelectionStrategy(Enum):\n \"\"\"\n The strategies for selecting new goals when\n c"
},
{
"path": "stable_baselines3/her/her_replay_buffer.py",
"chars": 18943,
"preview": "import copy\nimport warnings\nfrom typing import Any\n\nimport numpy as np\nimport torch as th\nfrom gymnasium import spaces\n\n"
},
{
"path": "stable_baselines3/ppo/__init__.py",
"chars": 189,
"preview": "from stable_baselines3.ppo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy\nfrom stable_baselines3.ppo.ppo import "
},
{
"path": "stable_baselines3/ppo/policies.py",
"chars": 301,
"preview": "# This file is here just to define MlpPolicy/CnnPolicy\n# that work for PPO\nfrom stable_baselines3.common.policies import"
},
{
"path": "stable_baselines3/ppo/ppo.py",
"chars": 15220,
"preview": "import warnings\nfrom typing import Any, ClassVar, TypeVar\n\nimport numpy as np\nimport torch as th\nfrom gymnasium import s"
},
{
"path": "stable_baselines3/py.typed",
"chars": 0,
"preview": ""
},
{
"path": "stable_baselines3/sac/__init__.py",
"chars": 189,
"preview": "from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy\nfrom stable_baselines3.sac.sac import "
},
{
"path": "stable_baselines3/sac/policies.py",
"chars": 20595,
"preview": "from typing import Any\n\nimport torch as th\nfrom gymnasium import spaces\nfrom torch import nn\n\nfrom stable_baselines3.com"
},
{
"path": "stable_baselines3/sac/sac.py",
"chars": 16287,
"preview": "from typing import Any, ClassVar, TypeVar\n\nimport numpy as np\nimport torch as th\nfrom gymnasium import spaces\nfrom torch"
},
{
"path": "stable_baselines3/td3/__init__.py",
"chars": 189,
"preview": "from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy, MultiInputPolicy\nfrom stable_baselines3.td3.td3 import "
},
{
"path": "stable_baselines3/td3/policies.py",
"chars": 14401,
"preview": "from typing import Any\n\nimport torch as th\nfrom gymnasium import spaces\nfrom torch import nn\n\nfrom stable_baselines3.com"
},
{
"path": "stable_baselines3/td3/td3.py",
"chars": 11502,
"preview": "from typing import Any, ClassVar, TypeVar\n\nimport numpy as np\nimport torch as th\nfrom gymnasium import spaces\nfrom torch"
},
{
"path": "stable_baselines3/version.txt",
"chars": 8,
"preview": "2.8.0a4\n"
},
{
"path": "tests/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "tests/test_buffers.py",
"chars": 9960,
"preview": "import gymnasium as gym\nimport numpy as np\nimport pytest\nimport torch as th\nfrom gymnasium import spaces\n\nfrom stable_ba"
},
{
"path": "tests/test_callbacks.py",
"chars": 10523,
"preview": "import os\nimport shutil\n\nimport gymnasium as gym\nimport numpy as np\nimport pytest\nimport torch as th\n\nfrom stable_baseli"
},
{
"path": "tests/test_cnn.py",
"chars": 13693,
"preview": "import os\nfrom copy import deepcopy\n\nimport numpy as np\nimport pytest\nimport torch as th\nfrom gymnasium import spaces\n\nf"
},
{
"path": "tests/test_custom_policy.py",
"chars": 4630,
"preview": "import pytest\nimport torch as th\nimport torch.nn as nn\n\nfrom stable_baselines3 import A2C, DQN, PPO, SAC, TD3\nfrom stabl"
},
{
"path": "tests/test_deterministic.py",
"chars": 1362,
"preview": "import numpy as np\nimport pytest\n\nfrom stable_baselines3 import A2C, DQN, PPO, SAC, TD3\nfrom stable_baselines3.common.no"
},
{
"path": "tests/test_dict_env.py",
"chars": 11710,
"preview": "import gymnasium as gym\nimport numpy as np\nimport pytest\nfrom gymnasium import spaces\n\nfrom stable_baselines3 import A2C"
},
{
"path": "tests/test_distributions.py",
"chars": 9875,
"preview": "from copy import deepcopy\n\nimport gymnasium as gym\nimport numpy as np\nimport pytest\nimport torch as th\n\nfrom stable_base"
},
{
"path": "tests/test_env_checker.py",
"chars": 8866,
"preview": "from typing import Any\n\nimport gymnasium as gym\nimport numpy as np\nimport pytest\nfrom gymnasium import spaces\n\nfrom stab"
},
{
"path": "tests/test_envs.py",
"chars": 10573,
"preview": "import types\nimport warnings\n\nimport gymnasium as gym\nimport numpy as np\nimport pytest\nfrom gymnasium import spaces\n\nfro"
},
{
"path": "tests/test_gae.py",
"chars": 6749,
"preview": "import gymnasium as gym\nimport numpy as np\nimport pytest\nimport torch as th\nfrom gymnasium import spaces\n\nfrom stable_ba"
},
{
"path": "tests/test_her.py",
"chars": 16881,
"preview": "import os\nimport pathlib\nimport warnings\nfrom copy import deepcopy\n\nimport numpy as np\nimport pytest\nimport torch as th\n"
},
{
"path": "tests/test_identity.py",
"chars": 1997,
"preview": "import numpy as np\nimport pytest\n\nfrom stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3\nfrom stable_baselines3.com"
},
{
"path": "tests/test_logger.py",
"chars": 20888,
"preview": "import importlib.util\nimport os\nimport sys\nimport time\nfrom collections.abc import Sequence\nfrom io import TextIOBase\nfr"
},
{
"path": "tests/test_monitor.py",
"chars": 5010,
"preview": "import json\nimport os\nimport uuid\nimport warnings\n\nimport gymnasium as gym\nimport pandas\nimport pytest\n\nfrom stable_base"
},
{
"path": "tests/test_n_step_replay.py",
"chars": 6566,
"preview": "import gymnasium as gym\nimport numpy as np\nimport pytest\n\nfrom stable_baselines3 import DQN, SAC, TD3\nfrom stable_baseli"
},
{
"path": "tests/test_predict.py",
"chars": 4032,
"preview": "import gymnasium as gym\nimport numpy as np\nimport pytest\nimport torch as th\nfrom gymnasium import spaces\n\nfrom stable_ba"
},
{
"path": "tests/test_preprocessing.py",
"chars": 2429,
"preview": "import torch\nfrom gymnasium import spaces\n\nfrom stable_baselines3.common.preprocessing import get_obs_shape, preprocess_"
},
{
"path": "tests/test_run.py",
"chars": 7709,
"preview": "import gymnasium as gym\nimport numpy as np\nimport pytest\nimport torch as th\n\nfrom stable_baselines3 import A2C, DDPG, DQ"
},
{
"path": "tests/test_save_load.py",
"chars": 32957,
"preview": "import base64\nimport io\nimport json\nimport os\nimport pathlib\nimport tempfile\nimport warnings\nimport zipfile\nfrom collect"
},
{
"path": "tests/test_sde.py",
"chars": 4272,
"preview": "import gymnasium as gym\nimport numpy as np\nimport pytest\nimport torch as th\nfrom torch.distributions import Normal\n\nfrom"
},
{
"path": "tests/test_spaces.py",
"chars": 5440,
"preview": "from dataclasses import dataclass\n\nimport gymnasium as gym\nimport numpy as np\nimport pytest\nfrom gymnasium import spaces"
},
{
"path": "tests/test_tensorboard.py",
"chars": 2941,
"preview": "import os\n\nimport pytest\n\nfrom stable_baselines3 import A2C, PPO, SAC, TD3\nfrom stable_baselines3.common.callbacks impor"
},
{
"path": "tests/test_train_eval_mode.py",
"chars": 13172,
"preview": "import gymnasium as gym\nimport numpy as np\nimport pytest\nimport torch as th\nimport torch.nn as nn\n\nfrom stable_baselines"
},
{
"path": "tests/test_utils.py",
"chars": 24377,
"preview": "import os\nimport shutil\n\nimport ale_py\nimport gymnasium as gym\nimport numpy as np\nimport pytest\nimport torch as th\nfrom "
},
{
"path": "tests/test_vec_check_nan.py",
"chars": 1398,
"preview": "import gymnasium as gym\nimport numpy as np\nimport pytest\nfrom gymnasium import spaces\n\nfrom stable_baselines3.common.vec"
},
{
"path": "tests/test_vec_envs.py",
"chars": 24219,
"preview": "import collections\nimport functools\nimport itertools\nimport multiprocessing\nimport os\nimport warnings\n\nimport gymnasium "
},
{
"path": "tests/test_vec_extract_dict_obs.py",
"chars": 2941,
"preview": "import numpy as np\nfrom gymnasium import spaces\n\nfrom stable_baselines3 import PPO\nfrom stable_baselines3.common.vec_env"
},
{
"path": "tests/test_vec_monitor.py",
"chars": 5276,
"preview": "import csv\nimport json\nimport os\nimport uuid\nimport warnings\n\nimport gymnasium as gym\nimport pandas\nimport pytest\n\nfrom "
},
{
"path": "tests/test_vec_normalize.py",
"chars": 17892,
"preview": "import operator\nfrom typing import Any\n\nimport gymnasium as gym\nimport numpy as np\nimport pytest\nfrom gymnasium import s"
},
{
"path": "tests/test_vec_stacked_obs.py",
"chars": 14834,
"preview": "import numpy as np\nfrom gymnasium import spaces\n\nfrom stable_baselines3.common.vec_env.stacked_observations import Stack"
}
]
About this extraction
This page contains the full source code of the DLR-RM/stable-baselines3 GitHub repository, extracted and formatted as plain text. The extraction includes 170 files (1.3 MB), approximately 322.6k tokens, and a symbol index with 1248 extracted functions, classes, methods, constants, and types.
Extracted by GitExtract. Built by Nikandr Surkov.