stable-baselines3


Stable Baselines3

Overview

Stable Baselines3 (SB3) is a PyTorch-based library providing reliable implementations of reinforcement learning algorithms. This skill provides comprehensive guidance for training RL agents, creating custom environments, implementing callbacks, and optimizing training workflows using SB3's unified API.

Core Capabilities

1. Training RL Agents

Basic Training Pattern:

```python
import gymnasium as gym
from stable_baselines3 import PPO

# Create environment
env = gym.make("CartPole-v1")

# Initialize agent
model = PPO("MlpPolicy", env, verbose=1)

# Train the agent
model.learn(total_timesteps=10000)

# Save the model
model.save("ppo_cartpole")

# Load the model (without prior instantiation)
model = PPO.load("ppo_cartpole", env=env)
```

**Important Notes:**
- `total_timesteps` is a lower bound; actual training may exceed this due to batch collection
- Use `model.load()` as a static method, not on an existing instance
- The replay buffer is NOT saved with the model to save space

**Algorithm Selection:**
Use `references/algorithms.md` for detailed algorithm characteristics and selection guidance. Quick reference:
- **PPO/A2C**: General-purpose, supports all action space types, good for multiprocessing
- **SAC/TD3**: Continuous control, off-policy, sample-efficient
- **DQN**: Discrete actions, off-policy
- **HER**: Goal-conditioned tasks

See `scripts/train_rl_agent.py` for a complete training template with best practices.

2. Custom Environments


Requirements: Custom environments must inherit from gymnasium.Env and implement:
  • __init__(): Define action_space and observation_space
  • reset(seed, options): Return initial observation and info dict
  • step(action): Return observation, reward, terminated, truncated, info
  • render(): Visualization (optional)
  • close(): Cleanup resources

Key Constraints:
  • Image observations must be np.uint8 in range [0, 255]
  • Use channel-first format when possible (channels, height, width)
  • SB3 normalizes images automatically by dividing by 255
  • Set normalize_images=False in policy_kwargs if pre-normalized
  • SB3 does NOT support Discrete or MultiDiscrete spaces with start != 0

Validation:

```python
from stable_baselines3.common.env_checker import check_env

check_env(env, warn=True)
```

See scripts/custom_env_template.py for a complete custom environment template and references/custom_environments.md for comprehensive guidance.

3. Vectorized Environments


Purpose: Vectorized environments run multiple environment instances in parallel, accelerating training and enabling certain wrappers (frame-stacking, normalization).

Types:
  • DummyVecEnv: Sequential execution on the current process (for lightweight environments)
  • SubprocVecEnv: Parallel execution across processes (for compute-heavy environments)

Quick Setup:

```python
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv

# Create 4 parallel environments
env = make_vec_env("CartPole-v1", n_envs=4, vec_env_cls=SubprocVecEnv)

model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=25000)
```

**Off-Policy Optimization:**
When using multiple environments with off-policy algorithms (SAC, TD3, DQN), set `gradient_steps=-1` to perform one gradient update per environment step, balancing wall-clock time and sample efficiency.

**API Differences:**
- `reset()` returns only observations (info available in `vec_env.reset_infos`)
- `step()` returns 4-tuple: `(obs, rewards, dones, infos)` not 5-tuple
- Environments auto-reset after episodes
- Terminal observations available via `infos[env_idx]["terminal_observation"]`

See `references/vectorized_envs.md` for detailed information on wrappers and advanced usage.

4. Callbacks for Monitoring and Control


Purpose: Callbacks enable monitoring metrics, saving checkpoints, implementing early stopping, and custom training logic without modifying core algorithms.

Common Callbacks:
  • EvalCallback: Evaluate periodically and save the best model
  • CheckpointCallback: Save model checkpoints at intervals
  • StopTrainingOnRewardThreshold: Stop when a target reward is reached
  • ProgressBarCallback: Display training progress with timing

Custom Callback Structure:

```python
from stable_baselines3.common.callbacks import BaseCallback

class CustomCallback(BaseCallback):
    def _on_training_start(self):
        # Called before the first rollout
        pass

    def _on_step(self):
        # Called after each environment step
        # Return False to stop training
        return True

    def _on_rollout_end(self):
        # Called at the end of each rollout
        pass
```

Available Attributes:
  • self.model: The RL algorithm instance
  • self.num_timesteps: Total environment steps taken
  • self.training_env: The training environment

Chaining Callbacks:

```python
from stable_baselines3.common.callbacks import CallbackList

callback = CallbackList([eval_callback, checkpoint_callback, custom_callback])
model.learn(total_timesteps=10000, callback=callback)
```

See references/callbacks.md for comprehensive callback documentation.

5. Model Persistence and Inspection


Saving and Loading:

```python
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import VecNormalize

# Save model
model.save("model_name")

# Save normalization statistics (if using VecNormalize)
vec_env.save("vec_normalize.pkl")

# Load model
model = PPO.load("model_name", env=env)

# Load normalization statistics
vec_env = VecNormalize.load("vec_normalize.pkl", vec_env)
```

**Parameter Access:**

```python
# Get parameters
params = model.get_parameters()

# Set parameters
model.set_parameters(params)

# Access PyTorch state dict
state_dict = model.policy.state_dict()
```

6. Evaluation and Recording


Evaluation:

```python
from stable_baselines3.common.evaluation import evaluate_policy

mean_reward, std_reward = evaluate_policy(
    model,
    env,
    n_eval_episodes=10,
    deterministic=True
)
```

Video Recording:

```python
from stable_baselines3.common.vec_env import VecVideoRecorder

# Wrap environment with video recorder
env = VecVideoRecorder(
    env,
    "videos/",
    record_video_trigger=lambda x: x % 2000 == 0,
    video_length=200,
)
```

See `scripts/evaluate_agent.py` for a complete evaluation and recording template.

7. Advanced Features


Learning Rate Schedules:

```python
def linear_schedule(initial_value):
    def func(progress_remaining):
        # progress_remaining goes from 1 to 0
        return progress_remaining * initial_value
    return func

model = PPO("MlpPolicy", env, learning_rate=linear_schedule(0.001))
```
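Since the schedule is just a callable of progress_remaining, its behavior is easy to check directly:

```python
def linear_schedule(initial_value):
    def func(progress_remaining):
        # progress_remaining goes from 1 (start) to 0 (end of training)
        return progress_remaining * initial_value
    return func

lr = linear_schedule(0.001)
print(lr(1.0))  # 0.001 at the start of training
print(lr(0.5))  # 0.0005 halfway through
print(lr(0.0))  # 0.0 at the end
```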
Multi-Input Policies (Dict Observations):

```python
model = PPO("MultiInputPolicy", env, verbose=1)
```

Use when observations are dictionaries (e.g., combining images with sensor data).

Hindsight Experience Replay:

```python
from stable_baselines3 import SAC, HerReplayBuffer

model = SAC(
    "MultiInputPolicy",
    env,
    replay_buffer_class=HerReplayBuffer,
    replay_buffer_kwargs=dict(
        n_sampled_goal=4,
        goal_selection_strategy="future",
    ),
)
```

TensorBoard Integration:

```python
model = PPO("MlpPolicy", env, tensorboard_log="./tensorboard/")
model.learn(total_timesteps=10000)
```

Workflow Guidance


Starting a New RL Project:
  1. Define the problem: Identify observation space, action space, and reward structure
  2. Choose algorithm: Use references/algorithms.md for selection guidance
  3. Create/adapt environment: Use scripts/custom_env_template.py if needed
  4. Validate environment: Always run check_env() before training
  5. Set up training: Use scripts/train_rl_agent.py as a starting template
  6. Add monitoring: Implement callbacks for evaluation and checkpointing
  7. Optimize performance: Consider vectorized environments for speed
  8. Evaluate and iterate: Use scripts/evaluate_agent.py for assessment

Common Issues:
  • Memory errors: Reduce buffer_size for off-policy algorithms or use fewer parallel environments
  • Slow training: Consider SubprocVecEnv for parallel environments
  • Unstable training: Try different algorithms, tune hyperparameters, or check reward scaling
  • Import errors: Ensure stable_baselines3 is installed: uv pip install stable-baselines3[extra]

Resources


scripts/
  • train_rl_agent.py: Complete training script template with best practices
  • evaluate_agent.py: Agent evaluation and video recording template
  • custom_env_template.py: Custom Gym environment template

references/
  • algorithms.md: Detailed algorithm comparison and selection guide
  • custom_environments.md: Comprehensive custom environment creation guide
  • callbacks.md: Complete callback system reference
  • vectorized_envs.md: Vectorized environment usage and wrappers

Installation

```bash
# Basic installation
uv pip install stable-baselines3

# With extra dependencies (TensorBoard, etc.)
uv pip install stable-baselines3[extra]
```