sparse-autoencoder-training

SAELens: Sparse Autoencoders for Mechanistic Interpretability
SAELens is the primary library for training and analyzing Sparse Autoencoders (SAEs), a technique for decomposing polysemantic neural network activations into sparse, interpretable features. It builds on Anthropic's research on monosemanticity.
GitHub: jbloomAus/SAELens (1,100+ stars)
The Problem: Polysemanticity & Superposition
Individual neurons in neural networks are polysemantic - they activate in multiple, semantically distinct contexts. This happens because models use superposition to represent more features than they have neurons, making interpretability difficult.
SAEs solve this by decomposing dense activations into sparse, monosemantic features - typically only a small number of features activate for any given input, and each feature corresponds to an interpretable concept.
When to Use SAELens
Use SAELens when you need to:
- Discover interpretable features in model activations
- Understand what concepts a model has learned
- Study superposition and feature geometry
- Perform feature-based steering or ablation
- Analyze safety-relevant features (deception, bias, harmful content)
Consider alternatives when:
- You need basic activation analysis → Use TransformerLens directly
- You want causal intervention experiments → Use pyvene or TransformerLens
- You need production steering → Consider direct activation engineering
Installation
```bash
pip install sae-lens
```

Requirements: Python 3.10+, transformer-lens>=2.0.0
Core Concepts
What SAEs Learn
SAEs are trained to reconstruct model activations through a sparse bottleneck:
```
Input Activation → Encoder → Sparse Features → Decoder → Reconstructed Activation
   (d_model)               (d_sae >> d_model)             (d_model)
                                  ↓                           ↓
                           sparsity penalty          reconstruction loss
```

Loss Function:

```
MSE(original, reconstructed) + L1_coefficient × L1(features)
```
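This objective can be sketched end to end in a few lines. Below is a toy NumPy version with random weights; the sizes and initialization are illustrative only, not SAELens's implementation:

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_sae = 16, 128  # toy sizes; real SAEs use d_sae = 4-16x d_model

# Toy SAE parameters, randomly initialized for illustration only
W_enc = rng.normal(0.0, 0.1, (d_model, d_sae))
b_enc = np.zeros(d_sae)
W_dec = rng.normal(0.0, 0.1, (d_sae, d_model))
b_dec = np.zeros(d_model)

def sae_forward(x, l1_coefficient=8e-5):
    """Encode -> sparse features -> decode, plus the training loss."""
    features = np.maximum(x @ W_enc + b_enc, 0.0)  # ReLU keeps features sparse and non-negative
    recon = features @ W_dec + b_dec               # reconstruct the input activation
    mse = np.mean((x - recon) ** 2)
    loss = mse + l1_coefficient * np.abs(features).sum()  # MSE + L1 sparsity penalty
    return features, recon, loss

x = rng.normal(size=d_model)  # stand-in for one residual-stream vector
features, recon, loss = sae_forward(x)
print(f"active features: {(features > 0).sum()} / {d_sae}")
```

In a trained SAE the L1 term pushes most entries of `features` to exactly zero; with random weights the ReLU alone already zeroes roughly half.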
Key Validation (Anthropic Research)
In "Towards Monosemanticity", human evaluators found 70% of SAE features genuinely interpretable. Features discovered include:
- DNA sequences, legal language, HTTP requests
- Hebrew text, nutrition statements, code syntax
- Sentiment, named entities, grammatical structures
Workflow 1: Loading and Analyzing Pre-trained SAEs
Step-by-Step
```python
from transformer_lens import HookedTransformer
from sae_lens import SAE
```

1. Load model and pre-trained SAE
```python
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.8.hook_resid_pre",
    device="cuda",
)
```
2. Get model activations
```python
tokens = model.to_tokens("The capital of France is Paris")
_, cache = model.run_with_cache(tokens)
activations = cache["resid_pre", 8]  # [batch, pos, d_model]
```
3. Encode to SAE features
```python
sae_features = sae.encode(activations)  # [batch, pos, d_sae]
print(f"Active features: {(sae_features > 0).sum()}")
```
4. Find top features for each position
```python
for pos in range(tokens.shape[1]):
    top_features = sae_features[0, pos].topk(5)
    token = model.to_str_tokens(tokens[0, pos:pos+1])[0]
    print(f"Token '{token}': features {top_features.indices.tolist()}")
```
5. Reconstruct activations
```python
reconstructed = sae.decode(sae_features)
reconstruction_error = (activations - reconstructed).norm()
```

Available Pre-trained SAEs
| Release | Model | Layers |
|---|---|---|
| gpt2-small-res-jb | GPT-2 Small | Multiple residual streams |
| Gemma 2B releases | Gemma 2B | Residual streams |
| Various on HuggingFace (search by tag) | Various | Various |
Checklist
- Load model with TransformerLens
- Load matching SAE for target layer
- Encode activations to sparse features
- Identify top-activating features per token
- Validate reconstruction quality
Workflow 2: Training a Custom SAE
Step-by-Step
```python
from sae_lens import SAE, LanguageModelSAERunnerConfig, SAETrainingRunner
```

1. Configure training
```python
cfg = LanguageModelSAERunnerConfig(
    # Model
    model_name="gpt2-small",
    hook_name="blocks.8.hook_resid_pre",
    hook_layer=8,
    d_in=768,  # Model dimension
    # SAE architecture
    architecture="standard",  # or "gated", "topk"
    d_sae=768 * 8,  # Expansion factor of 8
    activation_fn="relu",
    # Training
    lr=4e-4,
    l1_coefficient=8e-5,  # Sparsity penalty
    l1_warm_up_steps=1000,
    train_batch_size_tokens=4096,
    training_tokens=100_000_000,
    # Data
    dataset_path="monology/pile-uncopyrighted",
    context_size=128,
    # Logging
    log_to_wandb=True,
    wandb_project="sae-training",
    # Checkpointing
    checkpoint_path="checkpoints",
    n_checkpoints=5,
)
```
2. Train
```python
trainer = SAETrainingRunner(cfg)
sae = trainer.run()
```
3. Evaluate
```python
print(f"L0 (avg active features): {trainer.metrics['l0']}")
print(f"CE Loss Recovered: {trainer.metrics['ce_loss_score']}")
```

Key Hyperparameters
| Parameter | Typical Value | Effect |
|---|---|---|
| d_sae | 4-16× d_model | More features, higher capacity |
| l1_coefficient | 5e-5 to 1e-4 | Higher = sparser, less accurate |
| lr | 1e-4 to 1e-3 | Standard optimizer LR |
| l1_warm_up_steps | 500-2000 | Prevents early feature death |
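The warm-up ramps the L1 penalty in gradually rather than applying it at full strength from step 0. A sketch of a linear schedule (the exact schedule SAELens uses internally may differ):

```python
def l1_schedule(step, l1_coefficient=8e-5, warm_up_steps=1000):
    """Linear L1 warm-up: ramp the penalty from 0 to its final value."""
    if warm_up_steps == 0:
        return l1_coefficient  # no warm-up: full penalty immediately
    return l1_coefficient * min(step / warm_up_steps, 1.0)

for step in (0, 500, 1000, 5000):
    print(step, l1_schedule(step))
```

Without this ramp, the full penalty hits randomly-initialized features before they reconstruct anything useful, which is one way features die early.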
Evaluation Metrics
| Metric | Target | Meaning |
|---|---|---|
| L0 | 50-200 | Average active features per token |
| CE Loss Score | 80-95% | Cross-entropy recovered vs original |
| Dead Features | <5% | Features that never activate |
| Explained Variance | >90% | Reconstruction quality |
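These metrics are straightforward to compute by hand. A toy NumPy illustration on random stand-in tensors (not real SAE outputs):

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy stand-ins for a batch of activations and SAE outputs
acts = rng.normal(size=(64, 16))                              # [tokens, d_model]
features = np.maximum(rng.normal(size=(64, 128)) - 1.5, 0.0)  # sparse-ish features
recon = acts + 0.1 * rng.normal(size=acts.shape)              # imperfect reconstruction

# L0: average number of active features per token
l0 = (features > 0).sum(axis=-1).mean()

# Explained variance: 1 - residual MSE / variance of the originals
explained_variance = 1.0 - ((acts - recon) ** 2).mean() / acts.var()

# Dead features: fraction of features that never fire on this batch
dead_fraction = ((features > 0).sum(axis=0) == 0).mean()

print(f"L0={l0:.1f}  EV={explained_variance:.3f}  dead={dead_fraction:.1%}")
```

In practice dead-feature counts are measured over a much larger sample than one batch, since a rarely-firing feature is not a dead one.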
Checklist
- Choose target layer and hook point
- Set expansion factor (d_sae = 4-16× d_model)
- Tune L1 coefficient for desired sparsity
- Enable L1 warm-up to prevent dead features
- Monitor metrics during training (W&B)
- Validate L0 and CE loss recovery
- Check dead feature ratio
Workflow 3: Feature Analysis and Steering
Analyzing Individual Features
```python
from transformer_lens import HookedTransformer
from sae_lens import SAE
import torch

model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
sae, _, _ = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.8.hook_resid_pre",
    device="cuda",
)
```

Find what activates a specific feature:
```python
feature_idx = 1234
test_texts = [
    "The scientist conducted an experiment",
    "I love chocolate cake",
    "The code compiles successfully",
    "Paris is beautiful in spring",
]
for text in test_texts:
    tokens = model.to_tokens(text)
    _, cache = model.run_with_cache(tokens)
    features = sae.encode(cache["resid_pre", 8])
    activation = features[0, :, feature_idx].max().item()
    print(f"{activation:.3f}: {text}")
```

Feature Steering
```python
def steer_with_feature(model, sae, prompt, feature_idx, strength=5.0):
    """Add an SAE feature direction to the residual stream during generation."""
    tokens = model.to_tokens(prompt)
    # Get the feature direction from the decoder weights
    feature_direction = sae.W_dec[feature_idx]  # [d_model]

    def steering_hook(activation, hook):
        # Add the scaled feature direction at all positions
        return activation + strength * feature_direction

    # model.generate does not accept fwd_hooks directly; attach them via the
    # hooks context manager instead
    with model.hooks(fwd_hooks=[("blocks.8.hook_resid_pre", steering_hook)]):
        output = model.generate(tokens, max_new_tokens=50)
    return model.to_string(output[0])
```

Feature Attribution
Which features most affect a specific output?

```python
tokens = model.to_tokens("The capital of France is")
_, cache = model.run_with_cache(tokens)

# Get features at the final position
features = sae.encode(cache["resid_pre", 8])[0, -1]  # [d_sae]

# Logit attribution per feature:
# feature contribution = feature_activation × decoder_weight × unembedding
W_dec = sae.W_dec  # [d_sae, d_model]
W_U = model.W_U    # [d_model, vocab]

# Contribution to the "Paris" logit
paris_token = model.to_single_token(" Paris")
feature_contributions = features * (W_dec @ W_U[:, paris_token])
top_features = feature_contributions.topk(10)
print("Top features for 'Paris' prediction:")
for idx, val in zip(top_features.indices, top_features.values):
    print(f"  Feature {idx.item()}: {val.item():.3f}")
```

Common Issues & Solutions
Issue: High dead feature ratio
```python
# WRONG: no warm-up, features die early
cfg = LanguageModelSAERunnerConfig(
    l1_coefficient=1e-4,
    l1_warm_up_steps=0,  # Bad!
)
```
RIGHT: warm up the L1 penalty

```python
cfg = LanguageModelSAERunnerConfig(
    l1_coefficient=8e-5,
    l1_warm_up_steps=1000,  # Gradually increase the penalty
    use_ghost_grads=True,   # Revive dead features
)
```
undefinedIssue: Poor reconstruction (low CE recovery)
Reduce the sparsity penalty and add capacity:

```python
cfg = LanguageModelSAERunnerConfig(
    l1_coefficient=5e-5,  # Lower = better reconstruction
    d_sae=768 * 16,       # More capacity
)
```
undefinedIssue: Features not interpretable
Increase sparsity (higher L1):

```python
cfg = LanguageModelSAERunnerConfig(
    l1_coefficient=1e-4,  # Higher = sparser, more interpretable
)
```
Or use TopK architecture
```python
cfg = LanguageModelSAERunnerConfig(
    architecture="topk",
    activation_fn_kwargs={"k": 50},  # Exactly 50 active features
)
```
undefinedIssue: Memory errors during training
```python
cfg = LanguageModelSAERunnerConfig(
    train_batch_size_tokens=2048,  # Reduce batch size
    store_batch_size_prompts=4,    # Fewer prompts in the buffer
    n_batches_in_buffer=8,         # Smaller activation buffer
)
```

Integration with Neuronpedia
Browse pre-trained SAE features at neuronpedia.org. Features are indexed by SAE ID; for example, gpt2-small layer 8 feature 1234 lives at:

```
neuronpedia.org/gpt2-small/8-res-jb/1234
```
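If you generate many such links, a small helper keeps them consistent. This is a hypothetical convenience function, with the path scheme inferred from the example above:

```python
def neuronpedia_url(model: str, sae_set: str, feature: int) -> str:
    """Build a Neuronpedia feature URL (path scheme inferred from the example)."""
    return f"https://neuronpedia.org/{model}/{sae_set}/{feature}"

print(neuronpedia_url("gpt2-small", "8-res-jb", 1234))
# https://neuronpedia.org/gpt2-small/8-res-jb/1234
```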
Key Classes Reference
| Class | Purpose |
|---|---|
| SAE | Sparse Autoencoder model |
| LanguageModelSAERunnerConfig | Training configuration |
| SAETrainingRunner | Training loop manager |
| ActivationsStore | Activation collection and batching |
| HookedSAETransformer | TransformerLens + SAE integration |
Reference Documentation
For detailed API documentation, tutorials, and advanced usage, see the references/ folder:

| File | Contents |
|---|---|
| references/README.md | Overview and quick start guide |
| references/api.md | Complete API reference for SAE, TrainingSAE, configurations |
| references/tutorials.md | Step-by-step tutorials for training, analysis, steering |
External Resources
Tutorials
Papers
- Towards Monosemanticity - Anthropic (2023)
- Scaling Monosemanticity - Anthropic (2024)
- Sparse Autoencoders Find Highly Interpretable Features - Cunningham et al. (ICLR 2024)
Official Documentation
- SAELens Docs
- Neuronpedia - Feature browser
SAE Architectures
| Architecture | Description | Use Case |
|---|---|---|
| Standard | ReLU + L1 penalty | General purpose |
| Gated | Learned gating mechanism | Better sparsity control |
| TopK | Exactly K active features | Consistent sparsity |
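The TopK row replaces the L1 penalty with a hard sparsity constraint: keep exactly K features per token. A minimal NumPy sketch of the activation function itself (illustrative only, not SAELens's internal implementation):

```python
import numpy as np

def topk_activation(pre_acts, k=50):
    """Keep the k largest pre-activations per row; zero everything else."""
    out = np.zeros_like(pre_acts)
    idx = np.argpartition(pre_acts, -k, axis=-1)[..., -k:]  # indices of the top-k values
    np.put_along_axis(out, idx, np.take_along_axis(pre_acts, idx, axis=-1), axis=-1)
    return out

pre = np.random.default_rng(0).normal(size=(4, 512))  # [tokens, d_sae] pre-activations
feats = topk_activation(pre, k=50)
print((feats != 0).sum(axis=-1))  # exactly 50 active features per token
```

Because sparsity is enforced structurally, TopK removes the L1-coefficient tuning loop: L0 equals k by construction.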
TopK SAE (exactly 50 features active):

```python
cfg = LanguageModelSAERunnerConfig(
    architecture="topk",
    activation_fn="topk",
    activation_fn_kwargs={"k": 50},
)
```