transformer-lens-interpretability
TransformerLens: Mechanistic Interpretability for Transformers
TransformerLens is the de facto standard library for mechanistic interpretability research on GPT-style language models. Created by Neel Nanda and maintained by Bryce Meyer, it provides clean interfaces to inspect and manipulate model internals via HookPoints on every activation.
GitHub: TransformerLensOrg/TransformerLens (2,900+ stars)
When to Use TransformerLens
Use TransformerLens when you need to:
- Reverse-engineer algorithms learned during training
- Perform activation patching / causal tracing experiments
- Study attention patterns and information flow
- Analyze circuits (e.g., induction heads, IOI circuit)
- Cache and inspect intermediate activations
- Apply direct logit attribution
Consider alternatives when:
- You need to work with non-transformer architectures → Use nnsight or pyvene
- You want to train/analyze Sparse Autoencoders → Use SAELens
- You need remote execution on massive models → Use nnsight with NDIF
- You want higher-level causal intervention abstractions → Use pyvene
Installation

```bash
pip install transformer-lens
```

For the development version:

```bash
pip install git+https://github.com/TransformerLensOrg/TransformerLens
```

Core Concepts
HookedTransformer
The main class that wraps transformer models with HookPoints on every activation:
```python
from transformer_lens import HookedTransformer

# Load a model
model = HookedTransformer.from_pretrained("gpt2-small")

# For gated models (LLaMA, Mistral), set your Hugging Face token first
import os
os.environ["HF_TOKEN"] = "your_token"
model = HookedTransformer.from_pretrained("meta-llama/Llama-2-7b-hf")
```

Supported Models (50+)
| Family | Models |
|---|---|
| GPT-2 | gpt2, gpt2-medium, gpt2-large, gpt2-xl |
| LLaMA | llama-7b, llama-13b, llama-2-7b, llama-2-13b |
| EleutherAI | pythia-70m to pythia-12b, gpt-neo, gpt-j-6b |
| Mistral | mistral-7b, mixtral-8x7b |
| Others | phi, qwen, opt, gemma |
Activation Caching
Run the model and cache all intermediate activations:
```python
# Get all activations
tokens = model.to_tokens("The Eiffel Tower is in")
logits, cache = model.run_with_cache(tokens)

# Access specific activations
residual = cache["resid_post", 5]    # Layer 5 residual stream
attn_pattern = cache["pattern", 3]   # Layer 3 attention pattern
mlp_out = cache["mlp_out", 7]        # Layer 7 MLP output

# Filter which activations to cache (saves memory)
logits, cache = model.run_with_cache(
    tokens,
    names_filter=lambda name: "resid_post" in name
)
```

ActivationCache Keys
| Key Pattern | Shape | Description |
|---|---|---|
| resid_pre | [batch, pos, d_model] | Residual before attention |
| resid_mid | [batch, pos, d_model] | Residual after attention |
| resid_post | [batch, pos, d_model] | Residual after MLP |
| attn_out | [batch, pos, d_model] | Attention output |
| mlp_out | [batch, pos, d_model] | MLP output |
| pattern | [batch, head, q_pos, k_pos] | Attention pattern (post-softmax) |
| q | [batch, pos, head, d_head] | Query vectors |
| k | [batch, pos, head, d_head] | Key vectors |
| v | [batch, pos, head, d_head] | Value vectors |
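The shorthand keys above expand to full hook-point names. TransformerLens provides `transformer_lens.utils.get_act_name` for this; as a rough sketch of the convention, here is a simplified mapping of our own covering only the keys in the table:

```python
# Sketch of how (name, layer) shorthand expands to full hook-point names.
# TransformerLens ships utils.get_act_name for this; FULL_NAMES below is
# our own simplified stand-in covering only the keys in the table above.
FULL_NAMES = {
    "resid_pre": "hook_resid_pre",
    "resid_mid": "hook_resid_mid",
    "resid_post": "hook_resid_post",
    "attn_out": "hook_attn_out",
    "mlp_out": "hook_mlp_out",
    "pattern": "attn.hook_pattern",
    "q": "attn.hook_q",
    "k": "attn.hook_k",
    "v": "attn.hook_v",
}

def act_name(name: str, layer: int) -> str:
    """Expand cache shorthand like ("resid_post", 5) to a full hook name."""
    return f"blocks.{layer}.{FULL_NAMES[name]}"

print(act_name("resid_post", 5))  # blocks.5.hook_resid_post
print(act_name("pattern", 3))     # blocks.3.attn.hook_pattern
```

These full names are what you pass to `run_with_hooks` via `fwd_hooks`, as in the workflows below.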
Workflow 1: Activation Patching (Causal Tracing)
Identify which activations causally affect model output by patching clean activations into corrupted runs.
Step-by-Step
```python
from transformer_lens import HookedTransformer
import torch

model = HookedTransformer.from_pretrained("gpt2-small")

# 1. Define clean and corrupted prompts
clean_prompt = "The Eiffel Tower is in the city of"
corrupted_prompt = "The Colosseum is in the city of"
clean_tokens = model.to_tokens(clean_prompt)
corrupted_tokens = model.to_tokens(corrupted_prompt)

# 2. Get clean activations
_, clean_cache = model.run_with_cache(clean_tokens)

# 3. Define metric (e.g., logit difference)
paris_token = model.to_single_token(" Paris")
rome_token = model.to_single_token(" Rome")

def metric(logits):
    return logits[0, -1, paris_token] - logits[0, -1, rome_token]

# 4. Patch each position and layer
results = torch.zeros(model.cfg.n_layers, clean_tokens.shape[1])
for layer in range(model.cfg.n_layers):
    for pos in range(clean_tokens.shape[1]):
        def patch_hook(activation, hook):
            activation[0, pos] = clean_cache[hook.name][0, pos]
            return activation
        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(f"blocks.{layer}.hook_resid_post", patch_hook)]
        )
        results[layer, pos] = metric(patched_logits)

# 5. Visualize results (layer x position heatmap)
```
Checklist
- Define clean and corrupted inputs that differ minimally
- Choose metric that captures behavior difference
- Cache clean activations
- Systematically patch each (layer, position) combination
- Visualize results as heatmap
- Identify causal hotspots
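A common refinement of the raw logit-difference metric is to normalize it so that 0 means "no better than the corrupted run" and 1 means "fully restored to clean performance", which makes heatmaps comparable across prompts. A minimal sketch; `make_normalized_metric` and the baseline values are our own illustrative choices, not part of the library:

```python
import torch

def make_normalized_metric(clean_value: float, corrupted_value: float):
    """Scale a raw metric so 0 = corrupted baseline, 1 = clean baseline."""
    def normalized(raw: torch.Tensor) -> torch.Tensor:
        return (raw - corrupted_value) / (clean_value - corrupted_value)
    return normalized

# Illustrative baseline values; in practice compute them by running the
# metric on the clean and corrupted logits before patching.
metric = make_normalized_metric(clean_value=4.0, corrupted_value=-1.0)
print(metric(torch.tensor(4.0)).item())   # 1.0 (patch fully restores behavior)
print(metric(torch.tensor(-1.0)).item())  # 0.0 (patch has no effect)
print(metric(torch.tensor(1.5)).item())   # 0.5 (partial restoration)
```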
Workflow 2: Circuit Analysis (Indirect Object Identification)
Replicate the IOI circuit discovery from "Interpretability in the Wild".
Step-by-Step
```python
from transformer_lens import HookedTransformer
import torch

model = HookedTransformer.from_pretrained("gpt2-small")

# IOI task: the model should predict " John" (the indirect object)
prompt = "When John and Mary went to the store, Mary gave a bottle to"
tokens = model.to_tokens(prompt)

# 1. Get baseline logits
logits, cache = model.run_with_cache(tokens)
john_token = model.to_single_token(" John")
mary_token = model.to_single_token(" Mary")

# 2. Compute logit difference (IO - S)
logit_diff = logits[0, -1, john_token] - logits[0, -1, mary_token]
print(f"Logit difference: {logit_diff.item():.3f}")

# 3. Direct logit attribution by head
def get_head_contribution(layer, head):
    # Project the head's output at the final position onto the logits
    head_out = cache["z", layer][0, :, head, :]  # [pos, d_head]
    W_O = model.W_O[layer, head]                 # [d_head, d_model]
    W_U = model.W_U                              # [d_model, vocab]
    contribution = head_out[-1] @ W_O @ W_U
    return contribution[john_token] - contribution[mary_token]

# 4. Map all heads
head_contributions = torch.zeros(model.cfg.n_layers, model.cfg.n_heads)
for layer in range(model.cfg.n_layers):
    for head in range(model.cfg.n_heads):
        head_contributions[layer, head] = get_head_contribution(layer, head)

# 5. Identify top contributing heads (name movers, backup name movers)
```
Checklist
- Set up task with clear IO/S tokens
- Compute baseline logit difference
- Decompose by attention head contributions
- Identify key circuit components (name movers, S-inhibition, induction)
- Validate with ablation experiments
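Direct logit attribution works because the unembedding is linear: the final residual stream is a sum of component outputs, so each component's contribution to the logit difference can be read off independently, and the contributions sum to the total (modulo LayerNorm scaling, which must be folded in or approximated). A self-contained sketch with random stand-in tensors; all names and sizes here are illustrative:

```python
import torch

torch.manual_seed(0)

d_model, vocab = 16, 50
n_components = 5  # stand-ins for per-head outputs, MLP outputs, embeddings

# Hypothetical component outputs at the final position; their sum is the
# residual stream fed to the unembedding (LayerNorm ignored for clarity).
components = torch.randn(n_components, d_model)
resid_final = components.sum(dim=0)
W_U = torch.randn(d_model, vocab)
io_tok, s_tok = 7, 21  # stand-ins for the " John" / " Mary" token ids

# Project onto the "logit difference direction" in the unembedding.
logit_diff_dir = W_U[:, io_tok] - W_U[:, s_tok]
total = resid_final @ logit_diff_dir
per_component = components @ logit_diff_dir

# Linearity: per-component contributions sum to the total logit difference.
assert torch.allclose(per_component.sum(), total, atol=1e-4)
print(per_component)  # each component's share of the logit difference
```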
Workflow 3: Induction Head Detection
Find induction heads that implement [A][B]...[A] → [B] pattern.
```python
from transformer_lens import HookedTransformer
import torch

model = HookedTransformer.from_pretrained("gpt2-small")

# Create repeated sequence: [A][B][A] should predict [B]
repeated_tokens = torch.tensor([[1000, 2000, 1000]])  # Arbitrary tokens
_, cache = model.run_with_cache(repeated_tokens)

# Induction heads attend from the final [A] back to the first [B]:
# check attention from position 2 to position 1
induction_scores = torch.zeros(model.cfg.n_layers, model.cfg.n_heads)
for layer in range(model.cfg.n_layers):
    pattern = cache["pattern", layer][0]  # [head, q_pos, k_pos]
    induction_scores[layer] = pattern[:, 2, 1]

# Heads with high scores are induction heads
top_heads = torch.topk(induction_scores.flatten(), k=5)
```
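With only three tokens this score is noisy; a common recipe averages attention at offset `period - 1` over a longer repeated random sequence. A sketch of just the scoring arithmetic, using a random row-normalized tensor as a stand-in for `cache["pattern", layer][0]`:

```python
import torch

torch.manual_seed(0)

# Toy attention pattern [head, q_pos, k_pos]; in practice take
# cache["pattern", layer][0] on a sequence of two repeats of length `period`.
n_heads, period = 4, 6
seq_len = 2 * period
pattern = torch.rand(n_heads, seq_len, seq_len)
pattern = pattern / pattern.sum(dim=-1, keepdim=True)  # rows sum to 1

# In the second repeat, an induction head at query position t attends to
# position t - (period - 1): the token after the previous occurrence of t.
q_pos = torch.arange(period, seq_len)
scores = pattern[:, q_pos, q_pos - (period - 1)].mean(dim=-1)
print(scores.shape)  # one induction score per head
```

On a random pattern the scores hover near 1/seq_len; real induction heads put most of their attention mass on that diagonal.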
Common Issues & Solutions
Issue: Hooks persist after debugging
```python
# WRONG: Old hooks remain active
model.run_with_hooks(tokens, fwd_hooks=[...])  # Debug, add new hooks
model.run_with_hooks(tokens, fwd_hooks=[...])  # Old hooks still there!

# RIGHT: Always reset hooks
model.reset_hooks()
model.run_with_hooks(tokens, fwd_hooks=[...])
```
Issue: Tokenization gotchas
```python
# WRONG: Assuming consistent tokenization
model.to_tokens("Tim")   # Single token
model.to_tokens("Neel")  # Becomes "Ne" + "el" (two tokens!)

# RIGHT: Check tokenization explicitly
tokens = model.to_tokens("Neel", prepend_bos=False)
print(model.to_str_tokens(tokens))  # ['Ne', 'el']
```
Issue: LayerNorm ignored in analysis
```python
# WRONG: Ignoring LayerNorm
pre_activation = residual @ model.W_in[layer]

# RIGHT: Apply the block's LayerNorm first
ln_out = model.blocks[layer].ln2(residual)
pre_activation = ln_out @ model.W_in[layer]
```
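Skipping LayerNorm matters because it re-centers and re-scales each residual vector before the MLP reads it, and residual norms grow with depth. A quick self-contained demonstration with `torch.nn.LayerNorm` and random stand-ins for `residual` and `W_in`:

```python
import torch

torch.manual_seed(0)

d_model, d_mlp = 8, 16
ln = torch.nn.LayerNorm(d_model)
residual = torch.randn(2, d_model) * 5.0  # exaggerated residual scale
w_in = torch.randn(d_model, d_mlp)

skipped = residual @ w_in       # ignores LayerNorm
correct = ln(residual) @ w_in   # what the MLP actually sees
# The two disagree because LN normalizes each vector before the matmul.
print(torch.allclose(skipped, correct))  # False
```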
Issue: Memory explosion with large models
```python
# Use selective caching
logits, cache = model.run_with_cache(
    tokens,
    names_filter=lambda n: "resid_post" in n or "pattern" in n,
    device="cpu"  # Cache on CPU
)
```
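Back-of-envelope arithmetic shows why full caches explode: one float32 tensor per hook point, per layer. The numbers below are purely illustrative (a 7B-scale config, roughly six residual-width hooks per layer):

```python
# Rough, illustrative estimate of full-cache memory for a large model.
n_layers, batch, seq, d_model = 32, 8, 2048, 4096
bytes_per_tensor = batch * seq * d_model * 4  # float32
hooks_per_layer = 6  # resid_pre/mid/post, attn_out, mlp_out, ... (approx.)
total_gb = n_layers * hooks_per_layer * bytes_per_tensor / 1e9
print(f"{total_gb:.1f} GB")  # order-of-magnitude estimate, not a measurement
```

A `names_filter` that keeps only one hook family cuts this by roughly the ratio of hooks kept to hooks total.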
Key Classes Reference
| Class | Purpose |
|---|---|
| HookedTransformer | Main model wrapper with hooks |
| ActivationCache | Dictionary-like cache of activations |
| HookedTransformerConfig | Model configuration |
| FactoredMatrix | Efficient factored matrix operations |
Integration with SAELens
TransformerLens integrates with SAELens for Sparse Autoencoder analysis:
```python
from transformer_lens import HookedTransformer
from sae_lens import SAE

model = HookedTransformer.from_pretrained("gpt2-small")
sae = SAE.from_pretrained("gpt2-small-res-jb", "blocks.8.hook_resid_pre")

# Run with SAE
tokens = model.to_tokens("Hello world")
_, cache = model.run_with_cache(tokens)
sae_acts = sae.encode(cache["resid_pre", 8])
```
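Under the hood, an SAE's encode step is roughly a linear map plus ReLU into a wider, sparse feature basis. A minimal sketch of that computation with our own randomly initialized weights (real SAELens SAEs use trained parameters and add bias/normalization details this omits):

```python
import torch

torch.manual_seed(0)

# Toy SAE encode: project the residual stream into a wider feature space.
d_model, d_sae = 16, 64  # the SAE dimension is typically much wider
W_enc = torch.randn(d_model, d_sae) / d_model**0.5
b_enc = torch.zeros(d_sae)

resid = torch.randn(3, d_model)  # stand-in for cache["resid_pre", 8]
sae_acts = torch.relu(resid @ W_enc + b_enc)  # nonnegative feature activations
print(sae_acts.shape)  # torch.Size([3, 64])
```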
Reference Documentation
For detailed API documentation, tutorials, and advanced usage, see the references/ folder:

| File | Contents |
|---|---|
| references/README.md | Overview and quick start guide |
| references/api.md | Complete API reference for HookedTransformer, ActivationCache, HookPoints |
| references/tutorials.md | Step-by-step tutorials for activation patching, circuit analysis, logit lens |
External Resources
Tutorials
- Main Demo Notebook
- Activation Patching Demo
- ARENA Mech Interp Course - 200+ hours of tutorials
Papers
Official Documentation
Version Notes
- v2.0: Removed HookedSAE (moved to SAELens)
- v3.0 (alpha): TransformerBridge for loading any nn.Module