OA0
OA0 是一个探索 AI 的社区
现在注册
已注册用户请  登录
OA0  ›  代码  ›  Stable-Baselines3 基于PyTorch的强化学习算法库

Stable-Baselines3 基于PyTorch的强化学习算法库

 
  release ·  2026-03-04 14:14:06 · 5 次点击  · 0 条评论  

CI
Documentation Status coverage report
codestyle

Stable Baselines3

Stable Baselines3 (SB3) 是一组基于 PyTorch 的强化学习算法的可靠实现。它是 Stable Baselines 的下一个主要版本。

你可以在 v1.0 博客文章 或我们的 JMLR 论文 中阅读关于 Stable Baselines3 的详细介绍。

这些算法将使研究社区和行业更容易复现、改进和识别新想法,并为构建项目提供良好的基线。我们期望这些工具能作为基础,便于在其上添加新想法,并作为比较新方法与现有方法的工具。我们也希望这些工具的简洁性能让初学者能够尝试更高级的工具集,而无需深陷于实现细节。

注意:尽管易于使用,但 Stable Baselines3 (SB3) 假定你对强化学习 (RL) 有一定了解。 没有实践经验不应直接使用此库。为此,我们在文档中提供了良好的入门资源。

主要特性

每个算法的性能都经过了测试(参见各自页面中的 结果 部分),你可以查看 issue #48#49 了解更多细节。

我们还在 OpenRL Benchmark 平台上提供了详细的日志和报告。

特性 Stable-Baselines3
先进的 RL 方法 :heavy_check_mark:
文档 :heavy_check_mark:
自定义环境 :heavy_check_mark:
自定义策略 :heavy_check_mark:
通用接口 :heavy_check_mark:
支持 Dict 观测空间 :heavy_check_mark:
兼容 Ipython / Notebook :heavy_check_mark:
支持 Tensorboard :heavy_check_mark:
PEP8 代码风格 :heavy_check_mark:
自定义回调 :heavy_check_mark:
高代码覆盖率 :heavy_check_mark:
类型提示 :heavy_check_mark:

计划特性

由于原始路线图中的大部分功能已经实现,SB3 目前没有重大的变更计划,现已进入稳定阶段。
如果你想贡献,可以在 issue 中搜索标记为需要帮助的议题以及其他建议的改进

虽然 SB3 的开发现在主要集中在错误修复和维护(文档更新、用户体验等),但在相关的代码库中有更活跃的开发:
- 新算法会定期添加到 SB3 Contrib 仓库
- 更快的变体在 SBX (SB3 + Jax) 仓库中开发
- SB3 的训练框架 RL Zoo 有一个活跃的路线图

迁移指南:从 Stable-Baselines (SB2) 到 Stable-Baselines3 (SB3)

从 SB2 迁移到 SB3 的指南可以在文档中找到。

文档

文档在线提供:https://stable-baselines3.readthedocs.io/

集成

Stable-Baselines3 与其他库/服务有一些集成,例如用于实验跟踪的 Weights & Biases 或用于存储/共享训练模型的 Hugging Face。你可以在文档的专门章节中找到更多信息。

RL Baselines3 Zoo:Stable Baselines3 强化学习智能体的训练框架

RL Baselines3 Zoo 是一个强化学习 (RL) 的训练框架。

它提供了用于训练、评估智能体、调整超参数、绘制结果和录制视频的脚本。

此外,它还包括针对常见环境和 RL 算法调优的超参数集合,以及使用这些设置训练的智能体。

该仓库的目标:

  1. 提供一个简单的接口来训练和运行 RL 智能体
  2. 对不同的强化学习算法进行基准测试
  3. 为每个环境和 RL 算法提供调优的超参数
  4. 享受训练好的智能体带来的乐趣!

Github 仓库:https://github.com/DLR-RM/rl-baselines3-zoo

文档:https://rl-baselines3-zoo.readthedocs.io/en/master/

SB3-Contrib:实验性 RL 特性

我们在一个单独的贡献仓库中实现实验性功能:SB3-Contrib

这使得 SB3 能够保持稳定和紧凑的核心,同时仍然提供最新的功能,如循环 PPO (PPO LSTM)、CrossQ、截断分位数评论家 (TQC)、分位数回归 DQN (QR-DQN) 或带无效动作掩码的 PPO (Maskable PPO)。

文档在线提供:https://sb3-contrib.readthedocs.io/

Stable-Baselines Jax (SBX)

Stable Baselines Jax (SBX) 是 Stable-Baselines3 在 Jax 中的概念验证版本,包含 DroQ 或 CrossQ 等最新算法。

与 SB3 相比,它提供的功能较少,但速度要快得多(最高可达 20 倍!):https://twitter.com/araffin2/status/1590714558628253698

安装

注意: Stable-Baselines3 支持 PyTorch >= 2.3

先决条件

Stable Baselines3 需要 Python 3.10+。

Windows

要在 Windows 上安装 stable-baselines,请查看文档

使用 pip 安装

安装 Stable Baselines3 包:

pip install 'stable-baselines3[extra]'

这包括可选的依赖项,如 Tensorboard、OpenCV 或 ale-py 用于在 Atari 游戏上训练。如果不需要这些,可以使用:

pip install stable-baselines3

请阅读文档了解更多细节和替代方案(从源代码安装、使用 docker 等)。

示例

库中的大部分代码都尝试遵循类似 scikit-learn 的语法来使用强化学习算法。

以下是一个如何在 cartpole 环境上训练和运行 PPO 的快速示例:

import gymnasium as gym

from stable_baselines3 import PPO

env = gym.make("CartPole-v1", render_mode="human")

model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000)

vec_env = model.get_env()
obs = vec_env.reset()
for i in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, info = vec_env.step(action)
    vec_env.render()
    # VecEnv 会自动重置
    # if done:
    #   obs = env.reset()

env.close()

或者,如果环境已在 Gymnasium 中注册策略已注册,只需一行代码即可训练模型:

from stable_baselines3 import PPO

model = PPO("MlpPolicy", "CartPole-v1").learn(10_000)

请阅读文档获取更多示例。

使用 Colab Notebooks 在线尝试!

以下所有示例都可以使用 Google Colab notebooks 在线执行:

已实现的算法

名称 循环 Box Discrete MultiDiscrete MultiBinary 多进程
ARS1 :x: :heavy_check_mark: :heavy_check_mark: :x: :x: :heavy_check_mark:
A2C :x: :heavy_check_mark: :heavy_check_mark: :heavy_check_mark: :heavy_check_mark: :heavy_check_mark:
CrossQ1 :x: :heavy_check_mark: :x: :x: :x: :heavy_check_mark:
DDPG :x: :heavy_check_mark: :x: :x: :x: :heavy_check_mark:
DQN :x: :x: :heavy_check_mark: :x: :x: :heavy_check_mark:
HER :x: :heavy_check_mark: :heavy_check_mark: :x: :x: :heavy_check_mark:
PPO :x: :heavy_check_mark: :heavy_check_mark: :heavy_check_mark: :heavy_check_mark: :heavy_check_mark:
QR-DQN1 :x: :x: :heavy_check_mark: :x: :x: :heavy_check_mark:
RecurrentPPO1 :heavy_check_mark: :heavy_check_mark: :heavy_check_mark: :heavy_check_mark: :heavy_check_mark: :heavy_check_mark:
SAC :x: :heavy_check_mark: :x: :x: :x: :heavy_check_mark:
TD3 :x: :heavy_check_mark: :x: :x: :x: :heavy_check_mark:
TQC1 :x: :heavy_check_mark: :x: :x: :x: :heavy_check_mark:
TRPO1 :x: :heavy_check_mark: :heavy_check_mark: :heavy_check_mark: :heavy_check_mark: :heavy_check_mark:
Maskable PPO1 :x: :x: :heavy_check_mark: :heavy_check_mark: :heavy_check_mark: :heavy_check_mark:

1: 在 SB3 Contrib GitHub 仓库中实现。

动作 gymnasium.spaces
* Box:一个 N 维盒子,包含动作空间中的每个点。
* Discrete:一个可能动作的列表,每个时间步只能使用其中一个动作。
* MultiDiscrete:一个可能动作的列表,每个时间步每个离散集合中只能使用一个动作。
* MultiBinary:一个可能动作的列表,每个时间步可以以任意组合使用任何动作。

测试安装

安装依赖

pip install -e '.[docs,tests,extra]'

运行测试

stable baselines3 中的所有单元测试都可以使用 pytest 运行器运行:

make pytest

运行单个测试文件:

python3 -m pytest -v tests/test_env_checker.py

运行单个测试:

python3 -m pytest -v -k 'test_check_env_dict_action'

你也可以使用 mypy 进行静态类型检查:

pip install mypy
make type

使用 ruff 进行代码风格检查:

pip install ruff
make lint

使用 Stable-Baselines3 的项目

我们尝试在文档中维护一个使用 stable-baselines3 的项目列表,如果你想你的项目出现在这个页面上,请告诉我们 ;)

引用本项目

在出版物中引用此仓库:

@article{stable-baselines3,
  author  = {Antonin Raffin and Ashley Hill and Adam Gleave and Anssi Kanervisto and Maximilian Ernestus and Noah Dormann},
  title   = {Stable-Baselines3: Reliable Reinforcement Learning Implementations},
  journal = {Journal of Machine Learning Research},
  year    = {2021},
  volume  = {22},
  number  = {268},
  pages   = {1-8},
  url     = {http://jmlr.org/papers/v22/20-1364.html}
}

注意:如果需要引用 SB3 的特定版本,也可以使用 Zenodo DOI

维护者

Stable-Baselines3 目前由 Ashley Hill (aka @hill-a)、Antonin Raffin (aka @araffin)、Maximilian Ernestus (aka @ernestum)、Adam Gleave (@AdamGleave)、Anssi Kanervisto (@Miffyli) 和 Quentin Gallouédec (@qgallouedec) 维护。

重要提示:我们不提供技术支持或咨询,也不通过电子邮件回答个人问题。
如有问题,请在 RL DiscordRedditStack Overflow 上发帖提问。

如何贡献

对于任何有兴趣改进基线的人,仍有一些文档需要完成。
如果你想贡献,请先阅读 CONTRIBUTING.md 指南。

致谢

5 次点击  ∙  0 人收藏  
登录后收藏  
0 条回复
关于 ·  帮助 ·  PING ·  隐私 ·  条款   
OA0 - Omni AI 0 一个探索 AI 的社区
沪ICP备2024103595号-2
耗时 23 ms
Developed with Cursor