sparse-autoencoder-training


SAELens: Sparse Autoencoders for Mechanistic Interpretability


SAELens is the primary library for training and analyzing Sparse Autoencoders (SAEs), a technique for decomposing polysemantic neural network activations into sparse, interpretable features. It builds on Anthropic's research on monosemanticity.
GitHub: jbloomAus/SAELens (1,100+ stars)

The Problem: Polysemanticity & Superposition


Individual neurons in neural networks are polysemantic - they activate in multiple, semantically distinct contexts. This happens because models use superposition to represent more features than they have neurons, making interpretability difficult.
SAEs solve this by decomposing dense activations into sparse, monosemantic features - typically only a small number of features activate for any given input, and each feature corresponds to an interpretable concept.
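The encode/decode structure behind this decomposition can be sketched as a tiny autoencoder. This is an illustration of the architecture, not the SAELens implementation; `TinySAE` and its parameter names are invented for the sketch.

```python
import torch
import torch.nn as nn

class TinySAE(nn.Module):
    """Illustrative SAE: dense activation -> sparse features -> reconstruction."""

    def __init__(self, d_model: int, d_sae: int):
        super().__init__()
        self.W_enc = nn.Parameter(torch.randn(d_model, d_sae) * 0.01)
        self.b_enc = nn.Parameter(torch.zeros(d_sae))
        self.W_dec = nn.Parameter(torch.randn(d_sae, d_model) * 0.01)
        self.b_dec = nn.Parameter(torch.zeros(d_model))

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        # ReLU keeps features non-negative; L1 pressure during training
        # drives most of them to exactly zero
        return torch.relu((x - self.b_dec) @ self.W_enc + self.b_enc)

    def decode(self, f: torch.Tensor) -> torch.Tensor:
        return f @ self.W_dec + self.b_dec

sae = TinySAE(d_model=768, d_sae=768 * 8)
x = torch.randn(4, 768)       # a batch of dense activations
features = sae.encode(x)      # [4, 6144], sparse once trained
x_hat = sae.decode(features)  # [4, 768] reconstruction
```

After training, each of the `d_sae` feature directions ideally corresponds to one interpretable concept.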

When to Use SAELens


Use SAELens when you need to:
  • Discover interpretable features in model activations
  • Understand what concepts a model has learned
  • Study superposition and feature geometry
  • Perform feature-based steering or ablation
  • Analyze safety-relevant features (deception, bias, harmful content)
Consider alternatives when:
  • You need basic activation analysis → Use TransformerLens directly
  • You want causal intervention experiments → Use pyvene or TransformerLens
  • You need production steering → Consider direct activation engineering

Installation


```bash
pip install sae-lens
```

Requirements: Python 3.10+, `transformer-lens>=2.0.0`

Core Concepts


What SAEs Learn


SAEs are trained to reconstruct model activations through a sparse bottleneck:

```
Input Activation → Encoder → Sparse Features → Decoder → Reconstructed Activation
    (d_model)        ↓      (d_sae >> d_model)    ↓           (d_model)
                  sparsity                  reconstruction
                  penalty                       loss
```

Loss Function:

```
MSE(original, reconstructed) + L1_coefficient × L1(features)
```
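The loss function can be written out in a few lines of PyTorch. This is a sketch, not the SAELens training loss; the tensors below are placeholders and `l1_coefficient` matches the typical value used in the training config later in this guide.

```python
import torch

def sae_loss(original: torch.Tensor, reconstructed: torch.Tensor,
             features: torch.Tensor, l1_coefficient: float = 8e-5) -> torch.Tensor:
    """Reconstruction MSE plus an L1 sparsity penalty on feature activations."""
    mse = (original - reconstructed).pow(2).mean()
    l1 = features.abs().sum(dim=-1).mean()  # average L1 norm per token
    return mse + l1_coefficient * l1

original = torch.randn(4, 768)                 # placeholder activations
features = torch.relu(torch.randn(4, 6144))    # placeholder sparse features
reconstructed = torch.randn(4, 768)            # placeholder reconstruction
loss = sae_loss(original, reconstructed, features)
```

Raising `l1_coefficient` trades reconstruction accuracy for sparser, typically more interpretable features.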

Key Validation (Anthropic Research)


In "Towards Monosemanticity", human evaluators found 70% of SAE features genuinely interpretable. Features discovered include:
  • DNA sequences, legal language, HTTP requests
  • Hebrew text, nutrition statements, code syntax
  • Sentiment, named entities, grammatical structures

Workflow 1: Loading and Analyzing Pre-trained SAEs


Step-by-Step


```python
from transformer_lens import HookedTransformer
from sae_lens import SAE
```

1. Load model and pre-trained SAE


```python
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.8.hook_resid_pre",
    device="cuda"
)
```

2. Get model activations


```python
tokens = model.to_tokens("The capital of France is Paris")
_, cache = model.run_with_cache(tokens)
activations = cache["resid_pre", 8]  # [batch, pos, d_model]
```

3. Encode to SAE features


```python
sae_features = sae.encode(activations)  # [batch, pos, d_sae]
print(f"Active features: {(sae_features > 0).sum()}")
```

4. Find top features for each position


```python
for pos in range(tokens.shape[1]):
    top_features = sae_features[0, pos].topk(5)
    token = model.to_str_tokens(tokens[0, pos:pos+1])[0]
    print(f"Token '{token}': features {top_features.indices.tolist()}")
```

5. Reconstruct activations


```python
reconstructed = sae.decode(sae_features)
reconstruction_error = (activations - reconstructed).norm()
```

Available Pre-trained SAEs


| Release | Model | Layers |
|---|---|---|
| `gpt2-small-res-jb` | GPT-2 Small | Multiple residual streams |
| `gemma-2b-res` | Gemma 2B | Residual streams |
| Various on HuggingFace (search tag `saelens`) | Various | Various |

Checklist


  • Load model with TransformerLens
  • Load matching SAE for target layer
  • Encode activations to sparse features
  • Identify top-activating features per token
  • Validate reconstruction quality

Workflow 2: Training a Custom SAE


Step-by-Step


```python
from sae_lens import SAE, LanguageModelSAERunnerConfig, SAETrainingRunner
```

1. Configure training


```python
cfg = LanguageModelSAERunnerConfig(
    # Model
    model_name="gpt2-small",
    hook_name="blocks.8.hook_resid_pre",
    hook_layer=8,
    d_in=768,  # Model dimension

    # SAE architecture
    architecture="standard",  # or "gated", "topk"
    d_sae=768 * 8,  # Expansion factor of 8
    activation_fn="relu",

    # Training
    lr=4e-4,
    l1_coefficient=8e-5,  # Sparsity penalty
    l1_warm_up_steps=1000,
    train_batch_size_tokens=4096,
    training_tokens=100_000_000,

    # Data
    dataset_path="monology/pile-uncopyrighted",
    context_size=128,

    # Logging
    log_to_wandb=True,
    wandb_project="sae-training",

    # Checkpointing
    checkpoint_path="checkpoints",
    n_checkpoints=5,
)
```

2. Train


```python
trainer = SAETrainingRunner(cfg)
sae = trainer.run()
```

3. Evaluate


```python
print(f"L0 (avg active features): {trainer.metrics['l0']}")
print(f"CE Loss Recovered: {trainer.metrics['ce_loss_score']}")
```

Key Hyperparameters


| Parameter | Typical Value | Effect |
|---|---|---|
| `d_sae` | 4-16× d_model | More features, higher capacity |
| `l1_coefficient` | 5e-5 to 1e-4 | Higher = sparser, less accurate |
| `lr` | 1e-4 to 1e-3 | Standard optimizer LR |
| `l1_warm_up_steps` | 500-2000 | Prevents early feature death |

Evaluation Metrics


| Metric | Target | Meaning |
|---|---|---|
| L0 | 50-200 | Average active features per token |
| CE Loss Score | 80-95% | Cross-entropy recovered vs original |
| Dead Features | <5% | Features that never activate |
| Explained Variance | >90% | Reconstruction quality |
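The first three metrics can be computed directly from a batch of activations and features. The helper below is a sketch, not a SAELens API; the CE Loss Score is omitted because it requires re-running the full model with the SAE reconstruction patched in, and dead-feature estimates really need far more tokens than a single batch.

```python
import torch

def sae_eval_metrics(activations: torch.Tensor, features: torch.Tensor,
                     reconstructed: torch.Tensor) -> dict:
    """Illustrative metric computation for [n_tokens, d]-shaped tensors."""
    active = features > 0
    l0 = active.float().sum(dim=-1).mean().item()               # avg active features per token
    dead = (active.sum(dim=0) == 0).float().mean().item()       # fraction never active here
    resid = (activations - reconstructed).pow(2).sum()
    total = (activations - activations.mean(dim=0)).pow(2).sum()
    explained_variance = (1 - resid / total).item()
    return {"l0": l0, "dead_features": dead, "explained_variance": explained_variance}
```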

Checklist


  • Choose target layer and hook point
  • Set expansion factor (d_sae = 4-16× d_model)
  • Tune L1 coefficient for desired sparsity
  • Enable L1 warm-up to prevent dead features
  • Monitor metrics during training (W&B)
  • Validate L0 and CE loss recovery
  • Check dead feature ratio

Workflow 3: Feature Analysis and Steering


Analyzing Individual Features


```python
from transformer_lens import HookedTransformer
from sae_lens import SAE
import torch

model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
sae, _, _ = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.8.hook_resid_pre",
    device="cuda"
)
```

Find what activates a specific feature


```python
feature_idx = 1234
test_texts = [
    "The scientist conducted an experiment",
    "I love chocolate cake",
    "The code compiles successfully",
    "Paris is beautiful in spring",
]

for text in test_texts:
    tokens = model.to_tokens(text)
    _, cache = model.run_with_cache(tokens)
    features = sae.encode(cache["resid_pre", 8])
    activation = features[0, :, feature_idx].max().item()
    print(f"{activation:.3f}: {text}")
```

Feature Steering


```python
def steer_with_feature(model, sae, prompt, feature_idx, strength=5.0):
    """Add SAE feature direction to residual stream."""
    tokens = model.to_tokens(prompt)

    # Get feature direction from decoder
    feature_direction = sae.W_dec[feature_idx]  # [d_model]

    def steering_hook(activation, hook):
        # Add scaled feature direction at all positions
        activation += strength * feature_direction
        return activation

    # Generate with the hook attached. HookedTransformer.generate does not
    # accept fwd_hooks directly, so use the hooks context manager.
    with model.hooks(fwd_hooks=[("blocks.8.hook_resid_pre", steering_hook)]):
        output = model.generate(tokens, max_new_tokens=50)
    return model.to_string(output[0])
```

Feature Attribution



Which features most affect a specific output?


```python
tokens = model.to_tokens("The capital of France is")
_, cache = model.run_with_cache(tokens)
```

Get features at final position


```python
features = sae.encode(cache["resid_pre", 8])[0, -1]  # [d_sae]
```

Get logit attribution per feature


Feature contribution = feature_activation × decoder_weight × unembedding


```python
W_dec = sae.W_dec  # [d_sae, d_model]
W_U = model.W_U    # [d_model, vocab]
```

Contribution to "Paris" logit


```python
paris_token = model.to_single_token(" Paris")
feature_contributions = features * (W_dec @ W_U[:, paris_token])

top_features = feature_contributions.topk(10)
print("Top features for 'Paris' prediction:")
for idx, val in zip(top_features.indices, top_features.values):
    print(f"  Feature {idx.item()}: {val.item():.3f}")
```

Common Issues & Solutions


Issue: High dead feature ratio



WRONG: No warm-up, features die early


```python
cfg = LanguageModelSAERunnerConfig(
    l1_coefficient=1e-4,
    l1_warm_up_steps=0,  # Bad!
)
```

RIGHT: Warm-up L1 penalty


```python
cfg = LanguageModelSAERunnerConfig(
    l1_coefficient=8e-5,
    l1_warm_up_steps=1000,  # Gradually increase the penalty
    use_ghost_grads=True,   # Revive dead features
)
```

Issue: Poor reconstruction (low CE recovery)



Reduce sparsity penalty


```python
cfg = LanguageModelSAERunnerConfig(
    l1_coefficient=5e-5,  # Lower = better reconstruction
    d_sae=768 * 16,       # More capacity
)
```

Issue: Features not interpretable



Increase sparsity (higher L1)


```python
cfg = LanguageModelSAERunnerConfig(
    l1_coefficient=1e-4,  # Higher = sparser, more interpretable
)
```

Or use TopK architecture


```python
cfg = LanguageModelSAERunnerConfig(
    architecture="topk",
    activation_fn_kwargs={"k": 50},  # Exactly 50 active features
)
```

Issue: Memory errors during training


```python
cfg = LanguageModelSAERunnerConfig(
    train_batch_size_tokens=2048,  # Reduce batch size
    store_batch_size_prompts=4,    # Fewer prompts in buffer
    n_batches_in_buffer=8,         # Smaller activation buffer
)
```

Integration with Neuronpedia


Browse pre-trained SAE features at neuronpedia.org:

Features are indexed by SAE ID


Example: gpt2-small layer 8 feature 1234


→ neuronpedia.org/gpt2-small/8-res-jb/1234

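Building such links can be automated with a tiny helper. The URL scheme below is inferred from the example above, so treat it as an assumption rather than a documented Neuronpedia API.

```python
def neuronpedia_url(model_id: str, sae_set: str, feature_idx: int) -> str:
    """Build a Neuronpedia feature-page URL (scheme inferred from the example above)."""
    return f"https://neuronpedia.org/{model_id}/{sae_set}/{feature_idx}"

url = neuronpedia_url("gpt2-small", "8-res-jb", 1234)
```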


Key Classes Reference


| Class | Purpose |
|---|---|
| `SAE` | Sparse Autoencoder model |
| `LanguageModelSAERunnerConfig` | Training configuration |
| `SAETrainingRunner` | Training loop manager |
| `ActivationsStore` | Activation collection and batching |
| `HookedSAETransformer` | TransformerLens + SAE integration |

Reference Documentation


For detailed API documentation, tutorials, and advanced usage, see the `references/` folder:

| File | Contents |
|---|---|
| `references/README.md` | Overview and quick start guide |
| `references/api.md` | Complete API reference for SAE, TrainingSAE, configurations |
| `references/tutorials.md` | Step-by-step tutorials for training, analysis, steering |

External Resources


Tutorials


Papers


Official Documentation


SAE Architectures


| Architecture | Description | Use Case |
|---|---|---|
| Standard | ReLU + L1 penalty | General purpose |
| Gated | Learned gating mechanism | Better sparsity control |
| TopK | Exactly K active features | Consistent sparsity |
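The TopK mechanism is simple enough to sketch directly: keep the k largest pre-activations per token and zero the rest. This is an illustration of the idea, not SAELens's implementation (some TopK variants apply a ReLU before selecting the top k).

```python
import torch

def topk_activation(pre_acts: torch.Tensor, k: int) -> torch.Tensor:
    """Zero out everything except the k largest pre-activations per token."""
    values, indices = pre_acts.topk(k, dim=-1)
    sparse = torch.zeros_like(pre_acts)
    sparse.scatter_(-1, indices, values)  # place the kept values back
    return sparse

features = topk_activation(torch.randn(2, 100), k=5)
```

Because exactly k features fire per token, L0 is fixed by construction and no L1 penalty is needed.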

TopK SAE (exactly 50 features active)


```python
cfg = LanguageModelSAERunnerConfig(
    architecture="topk",
    activation_fn="topk",
    activation_fn_kwargs={"k": 50},
)
```