stable-baselines3
Stable Baselines3
Overview
Stable Baselines3 (SB3) is a PyTorch-based library providing reliable implementations of reinforcement learning algorithms. This skill provides comprehensive guidance for training RL agents, creating custom environments, implementing callbacks, and optimizing training workflows using SB3's unified API.
Core Capabilities
1. Training RL Agents
Basic Training Pattern:
```python
import gymnasium as gym
from stable_baselines3 import PPO

# Create environment
env = gym.make("CartPole-v1")

# Initialize agent
model = PPO("MlpPolicy", env, verbose=1)

# Train the agent
model.learn(total_timesteps=10000)

# Save the model
model.save("ppo_cartpole")

# Load the model (without prior instantiation)
model = PPO.load("ppo_cartpole", env=env)
```
**Important Notes:**
- `total_timesteps` is a lower bound; actual training may exceed this due to batch collection
- Use `model.load()` as a static method, not on an existing instance
- The replay buffer is NOT saved with the model to save space (for off-policy models it can be persisted explicitly, as sketched below)
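A minimal sketch of explicit replay-buffer persistence (assuming an off-policy SAC model; the filenames are illustrative):
```python
from stable_baselines3 import SAC

model = SAC("MlpPolicy", "Pendulum-v1", verbose=0)
model.learn(total_timesteps=1000)

# The buffer is saved separately from the model
model.save("sac_pendulum")
model.save_replay_buffer("sac_pendulum_buffer")

# Restore both to resume training without an empty buffer
loaded = SAC.load("sac_pendulum")
loaded.load_replay_buffer("sac_pendulum_buffer")
```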
**Algorithm Selection:**
Use `references/algorithms.md` for detailed algorithm characteristics and selection guidance. Quick reference:
- **PPO/A2C**: General-purpose, supports all action space types, good for multiprocessing
- **SAC/TD3**: Continuous control, off-policy, sample-efficient
- **DQN**: Discrete actions, off-policy
- **HER**: Goal-conditioned tasks
See `scripts/train_rl_agent.py` for a complete training template with best practices.
2. Custom Environments
Requirements:
Custom environments must inherit from `gymnasium.Env` and implement:
- `__init__()`: Define action_space and observation_space
- `reset(seed, options)`: Return initial observation and info dict
- `step(action)`: Return observation, reward, terminated, truncated, info
- `render()`: Visualization (optional)
- `close()`: Cleanup resources

Key Constraints:
- Image observations must be `np.uint8` in range [0, 255]
- Use channel-first format when possible (channels, height, width)
- SB3 normalizes images automatically by dividing by 255
- Set `normalize_images=False` in `policy_kwargs` if pre-normalized
- SB3 does NOT support `Discrete` or `MultiDiscrete` spaces with `start != 0`

Validation:
```python
from stable_baselines3.common.env_checker import check_env

check_env(env, warn=True)
```
See `scripts/custom_env_template.py` for a complete custom environment template and `references/custom_environments.md` for comprehensive guidance.
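For orientation, a minimal sketch of the required interface (a hypothetical toy corridor environment, not the bundled template):
```python
import gymnasium as gym
import numpy as np
from gymnasium import spaces


class GoRightEnv(gym.Env):
    """Toy 1-D corridor: reach the rightmost cell for a reward of 1."""

    def __init__(self, size=10):
        super().__init__()
        self.size = size
        self.pos = 0
        self.action_space = spaces.Discrete(2)  # 0 = left, 1 = right
        self.observation_space = spaces.Box(
            low=0, high=size - 1, shape=(1,), dtype=np.float32
        )

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self.pos = 0
        return np.array([self.pos], dtype=np.float32), {}

    def step(self, action):
        self.pos = int(np.clip(self.pos + (1 if action == 1 else -1), 0, self.size - 1))
        terminated = self.pos == self.size - 1
        reward = 1.0 if terminated else 0.0
        # obs, reward, terminated, truncated, info
        return np.array([self.pos], dtype=np.float32), reward, terminated, False, {}
```
Running `check_env(GoRightEnv(), warn=True)` as above confirms the spaces and return signatures line up.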
3. Vectorized Environments
Purpose:
Vectorized environments run multiple environment instances in parallel, accelerating training and enabling certain wrappers (frame-stacking, normalization).
Types:
- DummyVecEnv: Sequential execution on current process (for lightweight environments)
- SubprocVecEnv: Parallel execution across processes (for compute-heavy environments)
Quick Setup:
```python
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv

# Create 4 parallel environments
env = make_vec_env("CartPole-v1", n_envs=4, vec_env_cls=SubprocVecEnv)

model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=25000)
```
**Off-Policy Optimization:**
When using multiple environments with off-policy algorithms (SAC, TD3, DQN), set `gradient_steps=-1` to perform one gradient update per environment step, balancing wall-clock time and sample efficiency.
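A minimal sketch of this setting (SAC on Pendulum-v1 with 4 environments; the env and step counts are illustrative):
```python
from stable_baselines3 import SAC
from stable_baselines3.common.env_util import make_vec_env

vec_env = make_vec_env("Pendulum-v1", n_envs=4)

# train_freq=1 trains after every rollout step; gradient_steps=-1 matches
# the number of gradient updates to the number of env transitions collected
model = SAC("MlpPolicy", vec_env, train_freq=1, gradient_steps=-1, verbose=1)
model.learn(total_timesteps=20000)
```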
**API Differences** (illustrated in the loop sketch after this list):
- `reset()` returns only observations (info available in `vec_env.reset_infos`)
- `step()` returns 4-tuple: `(obs, rewards, dones, infos)` not 5-tuple
- Environments auto-reset after episodes
- Terminal observations available via `infos[env_idx]["terminal_observation"]`
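A sketch of an interaction loop against a vectorized env (reusing `vec_env` from the sketch above; purely illustrative):
```python
import numpy as np

obs = vec_env.reset()  # returns observations only; infos live in vec_env.reset_infos
for _ in range(100):
    # Sample one action per sub-environment
    actions = np.stack([vec_env.action_space.sample() for _ in range(vec_env.num_envs)])
    obs, rewards, dones, infos = vec_env.step(actions)  # 4-tuple, not 5
    for idx, done in enumerate(dones):
        if done:
            # The sub-env auto-reset, so obs[idx] already starts a new episode;
            # the final observation of the finished episode is in infos
            final_obs = infos[idx]["terminal_observation"]
```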
See `references/vectorized_envs.md` for detailed information on wrappers and advanced usage.
4. Callbacks for Monitoring and Control
Purpose:
Callbacks enable monitoring metrics, saving checkpoints, implementing early stopping, and custom training logic without modifying core algorithms.
Common Callbacks (combined in the usage sketch after this list):
- EvalCallback: Evaluate periodically and save best model
- CheckpointCallback: Save model checkpoints at intervals
- StopTrainingOnRewardThreshold: Stop when target reward reached
- ProgressBarCallback: Display training progress with timing
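A minimal sketch wiring two of these together (the paths, frequencies, and the separate `eval_env` are illustrative assumptions):
```python
from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback

# Periodically evaluate on a held-out env and keep the best model so far
eval_callback = EvalCallback(
    eval_env,
    best_model_save_path="./logs/best_model/",
    log_path="./logs/eval/",
    eval_freq=5000,
    deterministic=True,
)

# Snapshot the model at fixed intervals
checkpoint_callback = CheckpointCallback(
    save_freq=10000,
    save_path="./logs/checkpoints/",
    name_prefix="rl_model",
)

# learn() accepts a single callback or a list
model.learn(total_timesteps=100000, callback=[eval_callback, checkpoint_callback])
```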
Custom Callback Structure:
```python
from stable_baselines3.common.callbacks import BaseCallback


class CustomCallback(BaseCallback):
    def _on_training_start(self):
        # Called before first rollout
        pass

    def _on_step(self):
        # Called after each environment step
        # Return False to stop training
        return True

    def _on_rollout_end(self):
        # Called at end of rollout
        pass
```
Available Attributes:
- `self.model`: The RL algorithm instance
- `self.num_timesteps`: Total environment steps
- `self.training_env`: The training environment
Chaining Callbacks:
```python
from stable_baselines3.common.callbacks import CallbackList

callback = CallbackList([eval_callback, checkpoint_callback, custom_callback])
model.learn(total_timesteps=10000, callback=callback)
```
See `references/callbacks.md` for comprehensive callback documentation.
5. Model Persistence and Inspection
Saving and Loading:
```python
# Save model
model.save("model_name")

# Save normalization statistics (if using VecNormalize)
vec_env.save("vec_normalize.pkl")

# Load model
model = PPO.load("model_name", env=env)

# Load normalization statistics
vec_env = VecNormalize.load("vec_normalize.pkl", vec_env)
```

**Parameter Access:**
```python
# Get parameters
params = model.get_parameters()

# Set parameters
model.set_parameters(params)

# Access PyTorch state dict
state_dict = model.policy.state_dict()
```
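For context, a sketch of how the `VecNormalize` wrapper referenced above is typically set up (assuming an existing vectorized env `vec_env`):
```python
from stable_baselines3.common.vec_env import VecNormalize

# Normalize observations and rewards with running statistics
vec_env = VecNormalize(vec_env, norm_obs=True, norm_reward=True, clip_obs=10.0)

# At evaluation time: freeze statistics and report raw rewards
vec_env.training = False
vec_env.norm_reward = False
```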
6. Evaluation and Recording
Evaluation:
```python
from stable_baselines3.common.evaluation import evaluate_policy

mean_reward, std_reward = evaluate_policy(
    model,
    env,
    n_eval_episodes=10,
    deterministic=True
)
```

Video Recording:
```python
from stable_baselines3.common.vec_env import VecVideoRecorder

# Wrap environment with video recorder
env = VecVideoRecorder(
    env,
    "videos/",
    record_video_trigger=lambda x: x % 2000 == 0,
    video_length=200
)
```
See `scripts/evaluate_agent.py` for a complete evaluation and recording template.
7. Advanced Features
Learning Rate Schedules:
```python
def linear_schedule(initial_value):
    def func(progress_remaining):
        # progress_remaining goes from 1 to 0
        return progress_remaining * initial_value
    return func

model = PPO("MlpPolicy", env, learning_rate=linear_schedule(0.001))
```

Multi-Input Policies (Dict Observations):
```python
model = PPO("MultiInputPolicy", env, verbose=1)
```
Use when observations are dictionaries (e.g., combining images with sensor data).

Hindsight Experience Replay:
```python
from stable_baselines3 import SAC, HerReplayBuffer

model = SAC(
    "MultiInputPolicy",
    env,
    replay_buffer_class=HerReplayBuffer,
    replay_buffer_kwargs=dict(
        n_sampled_goal=4,
        goal_selection_strategy="future",
    ),
)
```

TensorBoard Integration:
```python
model = PPO("MlpPolicy", env, tensorboard_log="./tensorboard/")
model.learn(total_timesteps=10000)
```
Workflow Guidance
Starting a New RL Project:
- Define the problem: Identify observation space, action space, and reward structure
- Choose algorithm: Use `references/algorithms.md` for selection guidance
- Create/adapt environment: Use `scripts/custom_env_template.py` if needed
- Validate environment: Always run `check_env()` before training
- Set up training: Use `scripts/train_rl_agent.py` as starting template
- Add monitoring: Implement callbacks for evaluation and checkpointing
- Optimize performance: Consider vectorized environments for speed
- Evaluate and iterate: Use `scripts/evaluate_agent.py` for assessment

Common Issues:
- Memory errors: Reduce `buffer_size` for off-policy algorithms or use fewer parallel environments
- Slow training: Consider SubprocVecEnv for parallel environments
- Unstable training: Try different algorithms, tune hyperparameters, or check reward scaling
- Import errors: Ensure `stable_baselines3` is installed: `uv pip install stable-baselines3[extra]`
Resources
scripts/
- `train_rl_agent.py`: Complete training script template with best practices
- `evaluate_agent.py`: Agent evaluation and video recording template
- `custom_env_template.py`: Custom Gym environment template
references/
- `algorithms.md`: Detailed algorithm comparison and selection guide
- `custom_environments.md`: Comprehensive custom environment creation guide
- `callbacks.md`: Complete callback system reference
- `vectorized_envs.md`: Vectorized environment usage and wrappers
Installation
```bash
# Basic installation
uv pip install stable-baselines3

# With extra dependencies (Tensorboard, etc.)
uv pip install stable-baselines3[extra]
```