simpo-training


SimPO - Simple Preference Optimization


Quick start


SimPO is a reference-free preference optimization method that outperforms DPO while eliminating the reference model entirely.
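The core idea can be sketched in a few lines (an illustrative sketch of the objective as described in the SimPO paper, not the repository's training code): the implicit reward is the length-normalized log-likelihood of a response under the policy itself, so no reference-model log-probs enter the loss anywhere.

```python
import math

def simpo_sigmoid_loss(chosen_logp_sum, chosen_len,
                       rejected_logp_sum, rejected_len,
                       beta=2.0, gamma_beta_ratio=0.5):
    """SimPO loss for one preference pair (sigmoid variant).

    The reward is beta times the *average* token log-prob of a response
    under the policy being trained -- length-normalized, and with no
    reference-model term (unlike DPO).
    """
    r_chosen = beta * chosen_logp_sum / chosen_len
    r_rejected = beta * rejected_logp_sum / rejected_len
    gamma = gamma_beta_ratio * beta        # target reward margin
    margin = r_chosen - r_rejected - gamma
    return -math.log(1.0 / (1.0 + math.exp(-margin)))  # -log sigmoid(margin)

# A pair where the chosen response is much more likely yields a small loss:
simpo_sigmoid_loss(-10.0, 10, -40.0, 10)  # -> approx 0.0067
```

The batched training loss averages this quantity over all preference pairs.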
Installation:

Create environment:

```bash
conda create -n simpo python=3.10 && conda activate simpo
```

Install PyTorch 2.2.2


Install alignment-handbook:

```bash
git clone https://github.com/huggingface/alignment-handbook.git
cd alignment-handbook
python -m pip install .
```

Install Flash Attention 2:

```bash
python -m pip install flash-attn --no-build-isolation
```

**Training** (Mistral 7B):

```bash
ACCELERATE_LOG_LEVEL=info accelerate launch \
  --config_file accelerate_configs/deepspeed_zero3.yaml \
  scripts/run_simpo.py \
  training_configs/mistral-7b-base-simpo.yaml
```

Common workflows


Workflow 1: Train from base model (Mistral 7B)


Config (`mistral-7b-base-simpo.yaml`):

Model:

```yaml
model_name_or_path: mistralai/Mistral-7B-v0.1
torch_dtype: bfloat16
```

Dataset:

```yaml
dataset_mixer:
  HuggingFaceH4/ultrafeedback_binarized: 1.0
dataset_splits:
  - train_prefs
  - test_prefs
```

SimPO hyperparameters:

```yaml
beta: 2.0               # Reward scaling (2.0-10.0)
gamma_beta_ratio: 0.5   # Target margin (0-1)
loss_type: sigmoid      # sigmoid or hinge
sft_weight: 0.0         # Optional SFT regularization
```
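The two `loss_type` options behave differently at the margin; a minimal illustration (not the training code), assuming the margin is the reward gap minus the target margin `gamma = beta * gamma_beta_ratio`:

```python
import math

def pair_loss(margin, loss_type="sigmoid"):
    # margin = reward_chosen - reward_rejected - gamma
    if loss_type == "sigmoid":
        # Smooth everywhere: keeps applying gradient even after the
        # target margin is met.
        return -math.log(1.0 / (1.0 + math.exp(-margin)))
    if loss_type == "hinge":
        # Piecewise linear: exactly zero once margin >= 1, so satisfied
        # pairs stop contributing gradient.
        return max(0.0, 1.0 - margin)
    raise ValueError(f"unknown loss_type: {loss_type}")
```

See `references/loss-functions.md` for when to prefer each.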

Training:

```yaml
learning_rate: 5e-7     # Critical: 3e-7 to 1e-6
num_train_epochs: 1
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
```
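The small per-device batch is compensated by gradient accumulation; the effective batch size per optimizer step is the product of both with the number of data-parallel GPUs (the 8-GPU node below is an assumption, not stated by the config):

```python
def effective_batch_size(per_device_batch, grad_accum_steps, num_gpus):
    # One optimizer step sees: micro-batch x accumulation steps x ranks.
    return per_device_batch * grad_accum_steps * num_gpus

# per_device_train_batch_size=1, gradient_accumulation_steps=8,
# on an (assumed) 8-GPU node:
effective_batch_size(1, 8, 8)  # -> 64
```

This is why OOM fixes below trade batch size for more accumulation steps: the product, and hence the optimization dynamics, stays the same.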

Output:

```yaml
output_dir: ./outputs/mistral-7b-simpo
```

**Launch training**:

```bash
accelerate launch --config_file accelerate_configs/deepspeed_zero3.yaml \
  scripts/run_simpo.py training_configs/mistral-7b-base-simpo.yaml
```

Workflow 2: Fine-tune instruct model (Llama 3 8B)


Config (`llama3-8b-instruct-simpo.yaml`):

```yaml
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct

dataset_mixer:
  argilla/ultrafeedback-binarized-preferences-cleaned: 1.0

beta: 2.5
gamma_beta_ratio: 0.5
learning_rate: 5e-7
sft_weight: 0.1             # Add SFT loss to preserve capabilities

num_train_epochs: 1
per_device_train_batch_size: 2
gradient_accumulation_steps: 4
output_dir: ./outputs/llama3-8b-simpo
```

Launch:

```bash
accelerate launch --config_file accelerate_configs/deepspeed_zero3.yaml \
  scripts/run_simpo.py training_configs/llama3-8b-instruct-simpo.yaml
```
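The `sft_weight: 0.1` line mixes a supervised term into the objective to anchor the instruct model's existing behavior; schematically (an illustrative sketch, not the script's internals):

```python
def combined_loss(preference_loss, chosen_nll, sft_weight=0.1):
    # Total objective: SimPO preference loss plus a weighted
    # negative-log-likelihood term on the chosen responses, which
    # regularizes the policy toward its starting (instruct) behavior.
    return preference_loss + sft_weight * chosen_nll
```

With `sft_weight: 0.0` (as in the base-model workflow above) the term vanishes and training is pure preference optimization.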

Workflow 3: Reasoning-intensive tasks (lower LR)


For math/code tasks:

```yaml
model_name_or_path: deepseek-ai/deepseek-math-7b-base

dataset_mixer:
  argilla/distilabel-math-preference-dpo: 1.0

beta: 5.0                   # Higher for a stronger signal
gamma_beta_ratio: 0.7       # Larger margin
learning_rate: 3e-7         # Lower LR for reasoning
sft_weight: 0.0

num_train_epochs: 1
per_device_train_batch_size: 1
gradient_accumulation_steps: 16
```

When to use vs alternatives


Use SimPO when you:
- Want simpler training than DPO (no reference model)
- Have preference data (chosen/rejected pairs)
- Need better performance than DPO
- Have limited compute resources
- Can train on a single node

Algorithm selection:
- **SimPO**: Simplest setup, strong performance, no reference model
- **DPO**: Needs a reference-model baseline; more conservative updates
- **PPO**: Maximum control, but requires a reward model and a complex setup
- **GRPO**: Memory-efficient RL, no critic

Use alternatives instead:
- **OpenRLHF**: Multi-node distributed training, PPO/GRPO
- **TRL**: Multiple alignment methods in one framework
- **DPO**: Established baseline for comparison

Common issues


**Issue: Loss divergence**

Reduce the learning rate:

```yaml
learning_rate: 3e-7  # Reduce from 5e-7
```

or reduce beta:

```yaml
beta: 1.0  # Reduce from 2.0
```

**Issue: Model forgets capabilities**

Add SFT regularization:

```yaml
sft_weight: 0.1  # Add an SFT loss component
```

**Issue: Poor preference separation**

Increase beta and the margin:

```yaml
beta: 5.0              # Increase from 2.0
gamma_beta_ratio: 0.8  # Increase from 0.5
```

**Issue: OOM during training**

Reduce the batch size:

```yaml
per_device_train_batch_size: 1
gradient_accumulation_steps: 16  # Maintain the effective batch size
```

Enable gradient checkpointing:

```yaml
gradient_checkpointing: true
```

Advanced topics


- **Loss functions**: See `references/loss-functions.md` for sigmoid vs hinge loss, mathematical formulations, and when to use each.
- **Hyperparameter tuning**: See `references/hyperparameters.md` for a beta, gamma, and learning-rate selection guide, plus model-size-specific recommendations.
- **Dataset preparation**: See `references/datasets.md` for preference data formats, quality filtering, and custom dataset creation.
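As a concrete illustration of the chosen/rejected pair data mentioned above, a single example might look like the following (the field names follow the common UltraFeedback-style layout and are an assumed shape; see `references/datasets.md` for the authoritative schema):

```python
# One hypothetical preference example in chosen/rejected pair format.
example = {
    "prompt": "Explain gradient checkpointing in one sentence.",
    "chosen": [
        {"role": "user",
         "content": "Explain gradient checkpointing in one sentence."},
        {"role": "assistant",
         "content": "It saves memory by recomputing activations in the backward pass."},
    ],
    "rejected": [
        {"role": "user",
         "content": "Explain gradient checkpointing in one sentence."},
        {"role": "assistant",
         "content": "It is a way to save checkpoints."},
    ],
}
```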

Hardware requirements


- GPU: NVIDIA A100/H100 recommended
- VRAM:
  - 7B model: 1× A100 40GB (DeepSpeed ZeRO-3)
  - 8B model: 2× A100 40GB
  - 70B model: 8× A100 80GB
- Single node: DeepSpeed ZeRO-3 is sufficient
- Mixed precision: BF16 recommended

Memory optimization:
- DeepSpeed ZeRO-3 (default config)
- Gradient checkpointing
- Flash Attention 2
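A back-of-the-envelope reading of the VRAM figures (rough weight-only arithmetic; real usage also depends on optimizer states, activations, and sequence length):

```python
def bf16_weight_shard_gb(num_params_billion, num_gpus=1):
    # bf16 stores 2 bytes per parameter; DeepSpeed ZeRO-3 shards the
    # weights (and optimizer states) evenly across data-parallel ranks.
    return num_params_billion * 2 / num_gpus

bf16_weight_shard_gb(7)               # -> 14.0 GB of weights on one GPU
bf16_weight_shard_gb(70, num_gpus=8)  # -> 17.5 GB of weights per GPU
```

This is why a 7B model fits on a single A100 40GB only with ZeRO-3 plus gradient checkpointing: the weights alone take 14 GB before optimizer states and activations.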

Resources
