
Stable Baselines3 (SB3) 是一组基于 PyTorch 的强化学习算法的可靠实现。它是 Stable Baselines 的下一个主要版本。
你可以在 v1.0 博客文章 或我们的 JMLR 论文 中阅读关于 Stable Baselines3 的详细介绍。
这些算法将使研究社区和行业更容易复现、改进和识别新想法,并为构建项目提供良好的基线。我们期望这些工具能作为基础,便于在其上添加新想法,并作为比较新方法与现有方法的工具。我们也希望这些工具的简洁性能让初学者能够尝试更高级的工具集,而无需深陷于实现细节。
注意:尽管易于使用,但 Stable Baselines3 (SB3) 假定你对强化学习 (RL) 有一定了解。 没有实践经验不应直接使用此库。为此,我们在文档中提供了良好的入门资源。
每个算法的性能都经过了测试(参见各自页面中的 结果 部分),你可以查看 issue #48 和 #49 了解更多细节。
我们还在 OpenRL Benchmark 平台上提供了详细的日志和报告。
| 特性 | Stable-Baselines3 |
|---|---|
| 先进的 RL 方法 | :heavy_check_mark: |
| 文档 | :heavy_check_mark: |
| 自定义环境 | :heavy_check_mark: |
| 自定义策略 | :heavy_check_mark: |
| 通用接口 | :heavy_check_mark: |
支持 Dict 观测空间 |
:heavy_check_mark: |
| 兼容 Ipython / Notebook | :heavy_check_mark: |
| 支持 Tensorboard | :heavy_check_mark: |
| PEP8 代码风格 | :heavy_check_mark: |
| 自定义回调 | :heavy_check_mark: |
| 高代码覆盖率 | :heavy_check_mark: |
| 类型提示 | :heavy_check_mark: |
由于原始路线图中的大部分功能已经实现,SB3 目前没有重大的变更计划,现已进入稳定阶段。
如果你想贡献,可以在 issue 中搜索标记为需要帮助的议题以及其他建议的改进。
虽然 SB3 的开发现在主要集中在错误修复和维护(文档更新、用户体验等),但在相关的代码库中有更活跃的开发:
- 新算法会定期添加到 SB3 Contrib 仓库
- 更快的变体在 SBX (SB3 + Jax) 仓库中开发
- SB3 的训练框架 RL Zoo 有一个活跃的路线图
从 SB2 迁移到 SB3 的指南可以在文档中找到。
文档在线提供:https://stable-baselines3.readthedocs.io/
Stable-Baselines3 与其他库/服务有一些集成,例如用于实验跟踪的 Weights & Biases 或用于存储/共享训练模型的 Hugging Face。你可以在文档的专门章节中找到更多信息。
RL Baselines3 Zoo 是一个强化学习 (RL) 的训练框架。
它提供了用于训练、评估智能体、调整超参数、绘制结果和录制视频的脚本。
此外,它还包括针对常见环境和 RL 算法调优的超参数集合,以及使用这些设置训练的智能体。
该仓库的目标:
Github 仓库:https://github.com/DLR-RM/rl-baselines3-zoo
文档:https://rl-baselines3-zoo.readthedocs.io/en/master/
我们在一个单独的贡献仓库中实现实验性功能:SB3-Contrib
这使得 SB3 能够保持稳定和紧凑的核心,同时仍然提供最新的功能,如循环 PPO (PPO LSTM)、CrossQ、截断分位数评论家 (TQC)、分位数回归 DQN (QR-DQN) 或带无效动作掩码的 PPO (Maskable PPO)。
文档在线提供:https://sb3-contrib.readthedocs.io/
Stable Baselines Jax (SBX) 是 Stable-Baselines3 在 Jax 中的概念验证版本,包含 DroQ 或 CrossQ 等最新算法。
与 SB3 相比,它提供的功能较少,但速度要快得多(最高可达 20 倍!):https://twitter.com/araffin2/status/1590714558628253698
注意: Stable-Baselines3 支持 PyTorch >= 2.3
Stable Baselines3 需要 Python 3.10+。
要在 Windows 上安装 stable-baselines,请查看文档。
安装 Stable Baselines3 包:
pip install 'stable-baselines3[extra]'
这包括可选的依赖项,如 Tensorboard、OpenCV 或 ale-py 用于在 Atari 游戏上训练。如果不需要这些,可以使用:
pip install stable-baselines3
请阅读文档了解更多细节和替代方案(从源代码安装、使用 docker 等)。
库中的大部分代码都尝试遵循类似 scikit-learn 的语法来使用强化学习算法。
以下是一个如何在 cartpole 环境上训练和运行 PPO 的快速示例:
import gymnasium as gym
from stable_baselines3 import PPO
env = gym.make("CartPole-v1", render_mode="human")
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000)
vec_env = model.get_env()
obs = vec_env.reset()
for i in range(1000):
action, _states = model.predict(obs, deterministic=True)
obs, reward, done, info = vec_env.step(action)
vec_env.render()
# VecEnv 会自动重置
# if done:
# obs = env.reset()
env.close()
或者,如果环境已在 Gymnasium 中注册且策略已注册,只需一行代码即可训练模型:
from stable_baselines3 import PPO
model = PPO("MlpPolicy", "CartPole-v1").learn(10_000)
请阅读文档获取更多示例。
以下所有示例都可以使用 Google Colab notebooks 在线执行:
| 名称 | 循环 | Box |
Discrete |
MultiDiscrete |
MultiBinary |
多进程 |
|---|---|---|---|---|---|---|
| ARS1 | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
| A2C | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| CrossQ1 | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
| DDPG | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
| DQN | :x: | :x: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
| HER | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
| PPO | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| QR-DQN1 | :x: | :x: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
| RecurrentPPO1 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| SAC | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
| TD3 | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
| TQC1 | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
| TRPO1 | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| Maskable PPO1 | :x: | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
1: 在 SB3 Contrib GitHub 仓库中实现。
动作 gymnasium.spaces:
* Box:一个 N 维盒子,包含动作空间中的每个点。
* Discrete:一个可能动作的列表,每个时间步只能使用其中一个动作。
* MultiDiscrete:一个可能动作的列表,每个时间步每个离散集合中只能使用一个动作。
* MultiBinary:一个可能动作的列表,每个时间步可以以任意组合使用任何动作。
pip install -e '.[docs,tests,extra]'
stable baselines3 中的所有单元测试都可以使用 pytest 运行器运行:
make pytest
运行单个测试文件:
python3 -m pytest -v tests/test_env_checker.py
运行单个测试:
python3 -m pytest -v -k 'test_check_env_dict_action'
你也可以使用 mypy 进行静态类型检查:
pip install mypy
make type
使用 ruff 进行代码风格检查:
pip install ruff
make lint
我们尝试在文档中维护一个使用 stable-baselines3 的项目列表,如果你想你的项目出现在这个页面上,请告诉我们 ;)
在出版物中引用此仓库:
@article{stable-baselines3,
author = {Antonin Raffin and Ashley Hill and Adam Gleave and Anssi Kanervisto and Maximilian Ernestus and Noah Dormann},
title = {Stable-Baselines3: Reliable Reinforcement Learning Implementations},
journal = {Journal of Machine Learning Research},
year = {2021},
volume = {22},
number = {268},
pages = {1-8},
url = {http://jmlr.org/papers/v22/20-1364.html}
}
注意:如果需要引用 SB3 的特定版本,也可以使用 Zenodo DOI。
Stable-Baselines3 目前由 Ashley Hill (aka @hill-a)、Antonin Raffin (aka @araffin)、Maximilian Ernestus (aka @ernestum)、Adam Gleave (@AdamGleave)、Anssi Kanervisto (@Miffyli) 和 Quentin Gallouédec (@qgallouedec) 维护。
重要提示:我们不提供技术支持或咨询,也不通过电子邮件回答个人问题。
如有问题,请在 RL Discord、Reddit 或 Stack Overflow 上发帖提问。
对于任何有兴趣改进基线的人,仍有一些文档需要完成。
如果你想贡献,请先阅读 CONTRIBUTING.md 指南。