transformer-lens-interpretability


TransformerLens: Mechanistic Interpretability for Transformers


TransformerLens is the de facto standard library for mechanistic interpretability research on GPT-style language models. Created by Neel Nanda and maintained by Bryce Meyer, it provides clean interfaces to inspect and manipulate model internals via HookPoints on every activation.
GitHub: TransformerLensOrg/TransformerLens (2,900+ stars)

When to Use TransformerLens


Use TransformerLens when you need to:
  • Reverse-engineer algorithms learned during training
  • Perform activation patching / causal tracing experiments
  • Study attention patterns and information flow
  • Analyze circuits (e.g., induction heads, IOI circuit)
  • Cache and inspect intermediate activations
  • Apply direct logit attribution
Consider alternatives when:
  • You need to work with non-transformer architectures → Use nnsight or pyvene
  • You want to train/analyze Sparse Autoencoders → Use SAELens
  • You need remote execution on massive models → Use nnsight with NDIF
  • You want higher-level causal intervention abstractions → Use pyvene

Installation


```bash
pip install transformer-lens
```
For the development version:
```bash
pip install git+https://github.com/TransformerLensOrg/TransformerLens
```

Core Concepts


HookedTransformer


The main class that wraps transformer models with HookPoints on every activation:
```python
from transformer_lens import HookedTransformer
```

Load a model


```python
model = HookedTransformer.from_pretrained("gpt2-small")
```

For gated models (LLaMA, Mistral)


```python
import os
os.environ["HF_TOKEN"] = "your_token"

model = HookedTransformer.from_pretrained("meta-llama/Llama-2-7b-hf")
```

Supported Models (50+)


| Family | Models |
|---|---|
| GPT-2 | gpt2, gpt2-medium, gpt2-large, gpt2-xl |
| LLaMA | llama-7b, llama-13b, llama-2-7b, llama-2-13b |
| EleutherAI | pythia-70m to pythia-12b, gpt-neo, gpt-j-6b |
| Mistral | mistral-7b, mixtral-8x7b |
| Others | phi, qwen, opt, gemma |

Activation Caching


Run the model and cache all intermediate activations:

Get all activations


```python
tokens = model.to_tokens("The Eiffel Tower is in")
logits, cache = model.run_with_cache(tokens)
```

Access specific activations


```python
residual = cache["resid_post", 5]    # Layer 5 residual stream
attn_pattern = cache["pattern", 3]   # Layer 3 attention pattern
mlp_out = cache["mlp_out", 7]        # Layer 7 MLP output
```

Filter which activations to cache (saves memory)


```python
logits, cache = model.run_with_cache(
    tokens,
    names_filter=lambda name: "resid_post" in name
)
```

ActivationCache Keys


| Key Pattern | Shape | Description |
|---|---|---|
| "resid_pre", layer | [batch, pos, d_model] | Residual stream before attention |
| "resid_mid", layer | [batch, pos, d_model] | Residual stream after attention |
| "resid_post", layer | [batch, pos, d_model] | Residual stream after MLP |
| "attn_out", layer | [batch, pos, d_model] | Attention output |
| "mlp_out", layer | [batch, pos, d_model] | MLP output |
| "pattern", layer | [batch, head, q_pos, k_pos] | Attention pattern (post-softmax) |
| "q", layer | [batch, pos, head, d_head] | Query vectors |
| "k", layer | [batch, pos, head, d_head] | Key vectors |
| "v", layer | [batch, pos, head, d_head] | Value vectors |
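The tuple shorthand in the table maps onto full hook-point names such as `blocks.5.hook_resid_post` (the cache accepts either form, and `transformer_lens.utils.get_act_name` performs the real mapping). A simplified, dependency-free sketch of the naming scheme, covering only the keys in the table above:

```python
def act_name(name: str, layer: int) -> str:
    """Toy version of TransformerLens's activation-naming scheme.

    Residual-stream and block-level outputs hang directly off the block;
    attention-internal activations (q, k, v, pattern, z) live under `.attn`.
    The real get_act_name handles more cases (ln scales, MLP internals).
    """
    attn_internal = {"q", "k", "v", "pattern", "z"}
    if name in attn_internal:
        return f"blocks.{layer}.attn.hook_{name}"
    return f"blocks.{layer}.hook_{name}"

print(act_name("resid_post", 5))  # blocks.5.hook_resid_post
print(act_name("pattern", 3))     # blocks.3.attn.hook_pattern
```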

Workflow 1: Activation Patching (Causal Tracing)

工作流1:激活修补(因果追踪)

Identify which activations causally affect model output by patching clean activations into corrupted runs.
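Before touching a real model, the logic can be made concrete with a dependency-free toy: a fake "model" that just sums per-position "activations". Patching one clean position into the corrupted run reveals which position carries the behavioural difference. All names here are illustrative, not TransformerLens API:

```python
def toy_model(acts):
    # Stand-in for a forward pass: the "output" is just the sum.
    return sum(acts)

clean_acts = [1.0, 5.0, 1.0]      # position 1 carries the signal
corrupted_acts = [1.0, 0.0, 1.0]

baseline = toy_model(corrupted_acts)
effects = []
for pos in range(len(clean_acts)):
    patched = list(corrupted_acts)
    patched[pos] = clean_acts[pos]   # splice in the clean activation
    effects.append(toy_model(patched) - baseline)

print(effects)  # [0.0, 5.0, 0.0] -> only position 1 matters causally
```
The real workflow below is the same loop, with positions replaced by (layer, position) pairs and the sum replaced by a forward pass plus a metric.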

Step-by-Step


```python
from transformer_lens import HookedTransformer, patching
import torch

model = HookedTransformer.from_pretrained("gpt2-small")
```

1. Define clean and corrupted prompts


```python
clean_prompt = "The Eiffel Tower is in the city of"
corrupted_prompt = "The Colosseum is in the city of"

clean_tokens = model.to_tokens(clean_prompt)
corrupted_tokens = model.to_tokens(corrupted_prompt)
```

2. Get clean activations


```python
_, clean_cache = model.run_with_cache(clean_tokens)
```

3. Define metric (e.g., logit difference)


```python
paris_token = model.to_single_token(" Paris")
rome_token = model.to_single_token(" Rome")

def metric(logits):
    return logits[0, -1, paris_token] - logits[0, -1, rome_token]
```
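A useful way to read this metric: because softmax probabilities are proportional to exp(logit), a logit difference of d means the model assigns " Paris" e^d times the probability of " Rome", and every other logit cancels out of the ratio. A quick stdlib check:

```python
import math

def prob_ratio_from_logit_diff(d):
    # softmax(a) / softmax(b) = exp(a - b): the shared normalizer cancels
    return math.exp(d)

print(round(prob_ratio_from_logit_diff(2.0), 3))  # 7.389
```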

4. Patch each position and layer


```python
results = torch.zeros(model.cfg.n_layers, clean_tokens.shape[1])

for layer in range(model.cfg.n_layers):
    for pos in range(clean_tokens.shape[1]):
        # Bind pos as a default argument so each hook patches its own position
        def patch_hook(activation, hook, pos=pos):
            activation[0, pos] = clean_cache[hook.name][0, pos]
            return activation

        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(f"blocks.{layer}.hook_resid_post", patch_hook)]
        )
        results[layer, pos] = metric(patched_logits)
```

5. Visualize results (layer x position heatmap)

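For a proper plot, `results` can be fed to matplotlib's `imshow` or plotly's heatmap. When no plotting library is at hand, a dependency-free way to eyeball the layer × position grid is a coarse text heatmap (a sketch; assumes `results` has been converted to a nested list of floats, e.g. via `results.tolist()`):

```python
def text_heatmap(grid):
    """Render a 2D list of scores as rows of density characters."""
    chars = " .:-=+*#%@"  # low -> high
    flat = [v for row in grid for v in row]
    lo, hi = min(flat), max(flat)
    span = (hi - lo) or 1.0
    lines = []
    for layer, row in enumerate(grid):
        cells = "".join(
            chars[int((v - lo) / span * (len(chars) - 1))] for v in row
        )
        lines.append(f"L{layer:02d} {cells}")
    return "\n".join(lines)

print(text_heatmap([[0.0, 0.1, 0.0], [0.2, 3.0, 0.1]]))
```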

Checklist


  • Define clean and corrupted inputs that differ minimally
  • Choose metric that captures behavior difference
  • Cache clean activations
  • Systematically patch each (layer, position) combination
  • Visualize results as heatmap
  • Identify causal hotspots

Workflow 2: Circuit Analysis (Indirect Object Identification)

工作流2:电路分析(间接对象识别)

Replicate the IOI circuit discovery from "Interpretability in the Wild".

Step-by-Step


```python
from transformer_lens import HookedTransformer
import torch

model = HookedTransformer.from_pretrained("gpt2-small")
```

IOI task: "When John and Mary went to the store, Mary gave a bottle to"


Model should predict "John" (indirect object)


```python
prompt = "When John and Mary went to the store, Mary gave a bottle to"
tokens = model.to_tokens(prompt)
```

1. Get baseline logits


```python
logits, cache = model.run_with_cache(tokens)

john_token = model.to_single_token(" John")
mary_token = model.to_single_token(" Mary")
```

2. Compute logit difference (IO - S)


```python
logit_diff = logits[0, -1, john_token] - logits[0, -1, mary_token]
print(f"Logit difference: {logit_diff.item():.3f}")
```

3. Direct logit attribution by head


```python
def get_head_contribution(layer, head):
    # Project head output to logits (ignores the final LayerNorm scaling;
    # see "LayerNorm ignored in analysis" under Common Issues)
    head_out = cache["z", layer][0, :, head, :]  # [pos, d_head]
    W_O = model.W_O[layer, head]                 # [d_head, d_model]
    W_U = model.W_U                              # [d_model, vocab]

    # Head contribution to logits at the final position
    contribution = head_out[-1] @ W_O @ W_U
    return contribution[john_token] - contribution[mary_token]
```
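The matrix chain is the whole trick: a head's output at the final position, pushed through W_O and then the unembedding W_U, yields that head's additive contribution to every logit. A toy stdlib version with 2-d matrices and purely illustrative numbers:

```python
def matvec(v, M):
    # v: [d_in], M: [d_in][d_out] -> [d_out]
    return [sum(v[i] * M[i][j] for i in range(len(v)))
            for j in range(len(M[0]))]

head_out = [1.0, 2.0]                # z at the final position, d_head = 2
W_O = [[1.0, 0.0], [0.0, 1.0]]      # d_head x d_model (identity toy)
W_U = [[2.0, 0.0], [0.0, 1.0]]      # d_model x vocab (2 tokens)

logit_contrib = matvec(matvec(head_out, W_O), W_U)
print(logit_contrib)                         # [2.0, 2.0]
print(logit_contrib[0] - logit_contrib[1])   # 0.0 -> this head is neutral
```
Because the residual stream is a sum of head and MLP outputs, these per-head contributions add up (modulo LayerNorm) to the full logits, which is what makes the decomposition in step 4 meaningful.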

4. Map all heads


```python
head_contributions = torch.zeros(model.cfg.n_layers, model.cfg.n_heads)

for layer in range(model.cfg.n_layers):
    for head in range(model.cfg.n_heads):
        head_contributions[layer, head] = get_head_contribution(layer, head)
```

5. Identify top contributing heads (name movers, backup name movers)

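With a real run you would apply `torch.topk` to `head_contributions`; a dependency-free sketch of the same scan over the (layer, head) grid, with illustrative numbers:

```python
def top_heads(grid, k=3):
    """Return the k (layer, head, value) triples with the largest |value|."""
    scored = [(layer, head, v)
              for layer, row in enumerate(grid)
              for head, v in enumerate(row)]
    return sorted(scored, key=lambda t: abs(t[2]), reverse=True)[:k]

toy = [[0.1, -2.0], [3.5, 0.0]]  # 2 layers x 2 heads, illustrative scores
print(top_heads(toy, k=2))  # [(1, 0, 3.5), (0, 1, -2.0)]
```
Large positive entries correspond to name movers; large negative entries are candidates for suppressing heads, which is why the sort is on absolute value.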

Checklist


  • Set up task with clear IO/S tokens
  • Compute baseline logit difference
  • Decompose by attention head contributions
  • Identify key circuit components (name movers, S-inhibition, induction)
  • Validate with ablation experiments

Workflow 3: Induction Head Detection


Find induction heads that implement [A][B]...[A] → [B] pattern.
```python
from transformer_lens import HookedTransformer
import torch

model = HookedTransformer.from_pretrained("gpt2-small")
```

Create repeated sequence: [A][B][A] should predict [B]


```python
repeated_tokens = torch.tensor([[1000, 2000, 1000]])  # Arbitrary tokens
_, cache = model.run_with_cache(repeated_tokens)
```

Induction heads attend from final [A] back to first [B]


Check attention from position 2 to position 1


```python
induction_scores = torch.zeros(model.cfg.n_layers, model.cfg.n_heads)

for layer in range(model.cfg.n_layers):
    pattern = cache["pattern", layer][0]  # [head, q_pos, k_pos]
    # Attention from pos 2 (second [A]) back to pos 1 ([B])
    induction_scores[layer] = pattern[:, 2, 1]
```

Heads with high scores are induction heads


```python
top_heads = torch.topk(induction_scores.flatten(), k=5)
```
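`topk` on a flattened tensor returns flat indices, and `divmod` by `n_heads` recovers the (layer, head) pairs. A stdlib sketch, assuming gpt2-small's 12 heads per layer:

```python
n_heads = 12  # gpt2-small

def decode_flat_indices(flat_indices, n_heads=n_heads):
    # flat index = layer * n_heads + head
    return [divmod(i, n_heads) for i in flat_indices]

print(decode_flat_indices([65, 7]))  # [(5, 5), (0, 7)]
```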

Common Issues & Solutions


Issue: Hooks persist after debugging



WRONG: Old hooks remain active


```python
model.run_with_hooks(tokens, fwd_hooks=[...])
# Debug, add new hooks
model.run_with_hooks(tokens, fwd_hooks=[...])  # Old hooks still there!
```

RIGHT: Always reset hooks


```python
model.reset_hooks()
model.run_with_hooks(tokens, fwd_hooks=[...])
```
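Recent versions of TransformerLens also expose a `model.hooks(...)` context manager that scopes hooks for you. The guarantee can be sketched generically with `try/finally` (toy classes below, not the library API):

```python
from contextlib import contextmanager

class ToyModel:
    def __init__(self):
        self.active_hooks = []

    def add_hooks(self, hooks):
        self.active_hooks.extend(hooks)

    def reset_hooks(self):
        self.active_hooks = []

@contextmanager
def temporary_hooks(model, hooks):
    model.add_hooks(hooks)
    try:
        yield model
    finally:
        model.reset_hooks()  # runs even if the experiment raises

m = ToyModel()
with temporary_hooks(m, ["patch_hook"]):
    assert m.active_hooks == ["patch_hook"]
print(m.active_hooks)  # [] -- hooks are gone after the block
```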

Issue: Tokenization gotchas



WRONG: Assuming consistent tokenization


```python
model.to_tokens("Tim")   # Single token
model.to_tokens("Neel")  # Becomes "Ne" + "el" (two tokens!)
```

RIGHT: Check tokenization explicitly


```python
tokens = model.to_tokens("Neel", prepend_bos=False)
print(model.to_str_tokens(tokens))  # ['Ne', 'el']
```

Issue: LayerNorm ignored in analysis



WRONG: Ignoring LayerNorm


```python
pre_activation = residual @ model.W_in[layer]
```

RIGHT: Include LayerNorm


```python
ln_out = model.blocks[layer].ln2(residual)
pre_activation = ln_out @ model.W_in[layer]
```
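Why this matters numerically: LayerNorm recenters and rescales every residual vector before the MLP reads it, so skipping it distorts magnitudes layer by layer. A stdlib reference implementation of the normalization step (the elementwise affine parameters of the real module are omitted):

```python
import math

def layernorm(x, eps=1e-5):
    """Normalize a vector to zero mean and (approximately) unit variance."""
    mu = sum(x) / len(x)
    var = sum((v - mu) ** 2 for v in x) / len(x)
    return [(v - mu) / math.sqrt(var + eps) for v in x]

out = layernorm([2.0, 4.0, 6.0])
print([round(v, 3) for v in out])  # [-1.225, 0.0, 1.225]
```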

Issue: Memory explosion with large models



Use selective caching


```python
logits, cache = model.run_with_cache(
    tokens,
    names_filter=lambda n: "resid_post" in n or "pattern" in n,
    device="cpu"  # Cache on CPU
)
```
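To see why selective caching matters, it helps to estimate the bytes needed to cache even one activation type across all layers. A stdlib sketch with gpt2-small-like dimensions (the batch and sequence length are assumed values for illustration):

```python
def cache_bytes(n_layers, batch, seq_len, d_model, bytes_per_el=4):
    # one [batch, pos, d_model] float32 tensor per layer
    return n_layers * batch * seq_len * d_model * bytes_per_el

mb = cache_bytes(n_layers=12, batch=8, seq_len=1024, d_model=768) / 2**20
print(f"resid_post cache: {mb:.0f} MiB")  # 288 MiB -- for one key alone
```
Caching every hook point multiplies this several-fold, and larger models scale it by d_model and layer count, which is what pushes full caches out of GPU memory.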

Key Classes Reference


| Class | Purpose |
|---|---|
| HookedTransformer | Main model wrapper with hooks |
| ActivationCache | Dictionary-like cache of activations |
| HookedTransformerConfig | Model configuration |
| FactoredMatrix | Efficient factored matrix operations |

Integration with SAELens


TransformerLens integrates with SAELens for Sparse Autoencoder analysis:
```python
from transformer_lens import HookedTransformer
from sae_lens import SAE

model = HookedTransformer.from_pretrained("gpt2-small")
sae = SAE.from_pretrained("gpt2-small-res-jb", "blocks.8.hook_resid_pre")
```

Run with SAE


```python
tokens = model.to_tokens("Hello world")
_, cache = model.run_with_cache(tokens)

sae_acts = sae.encode(cache["resid_pre", 8])
```
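Conceptually, the SAE encoder is just a linear map plus ReLU into a wider, sparsely-activating feature basis. A toy stdlib version (shapes and numbers are illustrative, not the SAELens weights):

```python
def relu(v):
    return [max(0.0, x) for x in v]

def sae_encode(x, W_enc, b_enc):
    # x: [d_model], W_enc: [d_model][d_sae], b_enc: [d_sae]
    pre = [sum(x[i] * W_enc[i][j] for i in range(len(x))) + b_enc[j]
           for j in range(len(b_enc))]
    return relu(pre)

x = [1.0, -1.0]
W_enc = [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]]  # d_model=2 -> d_sae=3
b_enc = [0.0, 0.0, 0.0]
print(sae_encode(x, W_enc, b_enc))  # [1.0, 0.0, 0.0] -- sparse
```
The nonzero entries of `sae_acts` are the interpretable "features" active on the cached residual stream.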

Reference Documentation


For detailed API documentation, tutorials, and advanced usage, see the references/ folder:

| File | Contents |
|---|---|
| references/README.md | Overview and quick start guide |
| references/api.md | Complete API reference for HookedTransformer, ActivationCache, HookPoints |
| references/tutorials.md | Step-by-step tutorials for activation patching, circuit analysis, logit lens |

External Resources


Tutorials


Papers


Official Documentation


Version Notes


  • v2.0: Removed HookedSAE (moved to SAELens)
  • v3.0 (alpha): TransformerBridge for loading any nn.Module