pyvene: Causal Interventions for Neural Networks
pyvene is Stanford NLP's library for performing causal interventions on PyTorch models. It provides a declarative, dict-based framework for activation patching, causal tracing, and interchange intervention training - making intervention experiments reproducible and shareable.
GitHub: stanfordnlp/pyvene (840+ stars)
Paper: pyvene: A Library for Understanding and Improving PyTorch Models via Interventions (NAACL 2024)
When to Use pyvene
Use pyvene when you need to:
- Perform causal tracing (ROME-style localization)
- Run activation patching experiments
- Conduct interchange intervention training (IIT)
- Test causal hypotheses about model components
- Share/reproduce intervention experiments via HuggingFace
- Work with any PyTorch architecture (not just transformers)
Consider alternatives when:
- You need exploratory activation analysis → Use TransformerLens
- You want to train/analyze SAEs → Use SAELens
- You need remote execution on massive models → Use nnsight
- You want lower-level control → Use nnsight
Installation
```bash
pip install pyvene
```

Standard import:

```python
import pyvene as pv
```

Core Concepts
IntervenableModel
The main class that wraps any PyTorch model with intervention capabilities:
```python
import pyvene as pv
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load base model
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Define intervention configuration
config = pv.IntervenableConfig(
    representations=[
        pv.RepresentationConfig(
            layer=8,
            component="block_output",
            intervention_type=pv.VanillaIntervention,
        )
    ]
)

# Create intervenable model
intervenable = pv.IntervenableModel(config, model)
```
Intervention Types
| Type | Description | Use Case |
|---|---|---|
| `VanillaIntervention` | Swap activations between runs | Activation patching |
| `AdditionIntervention` | Add activations to base run | Steering, ablation |
| `SubtractionIntervention` | Subtract activations | Ablation |
| `ZeroIntervention` | Zero out activations | Component knockout |
| `RotatedSpaceIntervention` | DAS trainable intervention | Causal discovery |
| `CollectIntervention` | Collect activations | Probing, analysis |
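The arithmetic these types perform can be illustrated on plain Python lists (a toy sketch of the concepts, not pyvene code):

```python
# Toy activation vectors from a "base" run and a "source" run
base = [0.5, -1.0, 2.0]
source = [1.5, 0.0, -0.5]

vanilla = source[:]                                  # swap: source replaces base
addition = [b + s for b, s in zip(base, source)]     # steer by adding source
subtraction = [b - s for b, s in zip(base, source)]  # ablate by removing source
zeroed = [0.0 for _ in base]                         # knock the component out

print(vanilla, addition, subtraction, zeroed)
```

In pyvene, the same arithmetic is applied to the model's hidden states at the configured layer and component.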
Component Targets
```python
# Available components to intervene on
components = [
    "block_input",                  # Input to transformer block
    "block_output",                 # Output of transformer block
    "mlp_input",                    # Input to MLP
    "mlp_output",                   # Output of MLP
    "mlp_activation",               # MLP hidden activations
    "attention_input",              # Input to attention
    "attention_output",             # Output of attention
    "attention_value_output",       # Attention value vectors
    "query_output",                 # Query vectors
    "key_output",                   # Key vectors
    "value_output",                 # Value vectors
    "head_attention_value_output",  # Per-head values
]
```
Workflow 1: Causal Tracing (ROME-style)
Locate where factual associations are stored by corrupting inputs and restoring activations.
Step-by-Step
```python
import pyvene as pv
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model = AutoModelForCausalLM.from_pretrained("gpt2-xl")
tokenizer = AutoTokenizer.from_pretrained("gpt2-xl")

# 1. Define clean and corrupted inputs
clean_prompt = "The Space Needle is in downtown"
corrupted_prompt = "The ##### ###### ## ## ########"  # Noise
clean_tokens = tokenizer(clean_prompt, return_tensors="pt")
corrupted_tokens = tokenizer(corrupted_prompt, return_tensors="pt")

# 2. Get clean activations (source)
with torch.no_grad():
    clean_outputs = model(**clean_tokens, output_hidden_states=True)
    clean_states = clean_outputs.hidden_states

# 3. Define the restoration intervention
def run_causal_trace(layer, position):
    """Restore the clean activation at a specific layer and position."""
    config = pv.IntervenableConfig(
        representations=[
            pv.RepresentationConfig(
                layer=layer,
                component="block_output",
                intervention_type=pv.VanillaIntervention,
                unit="pos",
                max_number_of_units=1,
            )
        ]
    )
    intervenable = pv.IntervenableModel(config, model)
    # Run with the intervention
    _, patched_outputs = intervenable(
        base=corrupted_tokens,
        sources=[clean_tokens],
        unit_locations={"sources->base": ([[[position]]], [[[position]]])},
        output_original_output=True,
    )
    # Return the probability of the correct token
    probs = torch.softmax(patched_outputs.logits[0, -1], dim=-1)
    seattle_token = tokenizer.encode(" Seattle")[0]
    return probs[seattle_token].item()

# 4. Sweep over layers and positions
n_layers = model.config.n_layer
seq_len = clean_tokens["input_ids"].shape[1]
results = torch.zeros(n_layers, seq_len)
for layer in range(n_layers):
    for pos in range(seq_len):
        results[layer, pos] = run_causal_trace(layer, pos)

# 5. Visualize as a layer x position heatmap:
# high values indicate causal importance.
```
Checklist
- Prepare clean prompt with target factual association
- Create corrupted version (noise or counterfactual)
- Define intervention config for each (layer, position)
- Run patching sweep
- Identify causal hotspots in heatmap
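Once the sweep finishes, identifying the hotspot in the last step is just an argmax over the results grid. A stdlib-only sketch with illustrative (made-up) probabilities:

```python
# Illustrative (made-up) restoration probabilities per (layer, position)
results = [
    [0.02, 0.03, 0.05, 0.02, 0.02, 0.04],  # layer 0
    [0.03, 0.04, 0.40, 0.05, 0.03, 0.06],  # layer 1: subject token matters
    [0.02, 0.03, 0.08, 0.04, 0.05, 0.65],  # layer 2: last token matters
]

# The causal hotspot is the (layer, position) with the highest restored probability
best = max(
    ((layer, pos, p) for layer, row in enumerate(results) for pos, p in enumerate(row)),
    key=lambda t: t[2],
)
print(best)  # → (2, 5, 0.65)
```

In real traces the characteristic ROME pattern is two hotspot bands: early-middle layers at the subject tokens and late layers at the final token.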
Workflow 2: Activation Patching for Circuit Analysis
Test which components are necessary for a specific behavior.
Step-by-Step
```python
import pyvene as pv
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# IOI task setup
clean_prompt = "When John and Mary went to the store, Mary gave a bottle to"
corrupted_prompt = "When John and Mary went to the store, John gave a bottle to"
clean_tokens = tokenizer(clean_prompt, return_tensors="pt")
corrupted_tokens = tokenizer(corrupted_prompt, return_tensors="pt")

john_token = tokenizer.encode(" John")[0]
mary_token = tokenizer.encode(" Mary")[0]

def logit_diff(logits):
    """IO - S logit difference."""
    return logits[0, -1, john_token] - logits[0, -1, mary_token]

# Patch attention output at each layer
def patch_attention(layer):
    config = pv.IntervenableConfig(
        representations=[
            pv.RepresentationConfig(
                layer=layer,
                component="attention_output",
                intervention_type=pv.VanillaIntervention,
            )
        ]
    )
    intervenable = pv.IntervenableModel(config, model)
    _, patched_outputs = intervenable(
        base=corrupted_tokens,
        sources=[clean_tokens],
    )
    return logit_diff(patched_outputs.logits).item()

# Find which layers matter
results = []
for layer in range(model.config.n_layer):
    diff = patch_attention(layer)
    results.append(diff)
    print(f"Layer {layer}: logit diff = {diff:.3f}")
```
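Raw logit differences are easier to compare across layers when normalized against the clean and corrupted baselines. This normalization is a common convention in patching work, not a pyvene API; the numbers below are illustrative:

```python
def patching_effect(patched, clean, corrupted):
    """Normalized recovery: 0 = corrupted baseline, 1 = full clean recovery."""
    return (patched - corrupted) / (clean - corrupted)

# Illustrative baseline logit diffs for the IOI task
clean_diff, corrupted_diff = 3.2, -1.1

# A patched run recovering half the behavior
print(round(patching_effect(1.05, clean_diff, corrupted_diff), 2))  # → 0.5
```

Layers whose normalized effect approaches 1 are strong candidates for the circuit carrying the behavior.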
Workflow 3: Interchange Intervention Training (IIT)
Train interventions to discover causal structure.
Step-by-Step
```python
import pyvene as pv
from transformers import AutoModelForCausalLM
import torch

model = AutoModelForCausalLM.from_pretrained("gpt2")

# 1. Define a trainable intervention
config = pv.IntervenableConfig(
    representations=[
        pv.RepresentationConfig(
            layer=6,
            component="block_output",
            intervention_type=pv.RotatedSpaceIntervention,  # Trainable
            low_rank_dimension=64,  # Learn a 64-dim subspace
        )
    ]
)
intervenable = pv.IntervenableModel(config, model)

# 2. Set up training
optimizer = torch.optim.Adam(
    intervenable.get_trainable_parameters(),
    lr=1e-4
)

# 3. Training loop (simplified; assumes a task-specific dataloader and criterion)
for base_input, source_input, target_output in dataloader:
    optimizer.zero_grad()
    _, outputs = intervenable(
        base=base_input,
        sources=[source_input],
    )
    loss = criterion(outputs.logits, target_output)
    loss.backward()
    optimizer.step()

# 4. Analyze the learned intervention.
# The rotation matrix reveals the causal subspace.
rotation = intervenable.interventions["layer.6.block_output"][0].rotate_layer
```
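The training labels in the loop above come from a high-level causal model: the target is what that model outputs when its aligned variable is swapped in from the source run. A toy pure-Python illustration of this counterfactual-label construction (no pyvene involved; the task and variable names are made up):

```python
# High-level causal model for (a + b) + c, with intermediate variable s = a + b
def high_level(a, b, c, s_override=None):
    s = a + b if s_override is None else s_override
    return s + c

base = (1, 2, 3)    # s = 3, output 6
source = (4, 5, 6)  # s = 9

# Interchange: the training label is the base output computed with s from the source
counterfactual_label = high_level(*base, s_override=4 + 5)
print(counterfactual_label)  # → 12
```

If the trained intervention makes the network reproduce these counterfactual labels, the targeted representation behaves causally like the high-level variable `s`.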
DAS (Distributed Alignment Search)
```python
# Low-rank rotation finds interpretable subspaces
config = pv.IntervenableConfig(
    representations=[
        pv.RepresentationConfig(
            layer=8,
            component="block_output",
            intervention_type=pv.LowRankRotatedSpaceIntervention,
            low_rank_dimension=1,  # Find a 1D causal direction
        )
    ]
)
```
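Conceptually, a rotated-space intervention swaps coordinates in a learned orthogonal basis rather than in the raw residual stream. A minimal 2-D sketch of the mechanics (plain Python; `rotate` and `das_intervene` are illustrative helpers, not pyvene functions):

```python
import math

def rotate(v, theta):
    """Apply a 2-D rotation by angle theta (an orthogonal change of basis)."""
    c, s = math.cos(theta), math.sin(theta)
    return [c * v[0] - s * v[1], s * v[0] + c * v[1]]

def das_intervene(base, source, theta, k=1):
    """Swap the first k coordinates in the rotated basis, then rotate back."""
    rb, rs = rotate(base, theta), rotate(source, theta)
    mixed = rs[:k] + rb[k:]
    return rotate(mixed, -theta)

# With theta = 0 and k = 1, only the first raw coordinate is swapped
out = das_intervene([1.0, 2.0], [3.0, 4.0], theta=0.0, k=1)
print(out)  # → [3.0, 2.0]
```

DAS trains `theta` (in general, a full rotation matrix) so that swapping the first `k` rotated coordinates produces the counterfactual behavior, revealing a distributed causal subspace.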
Workflow 4: Model Steering (Honest LLaMA)
Steer model behavior during generation.
```python
import pyvene as pv
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# Load a pre-trained steering intervention
intervenable = pv.IntervenableModel.load(
    "zhengxuanzenwu/intervenable_honest_llama2_chat_7B",
    model=model,
)

# Generate with steering
prompt = "Is the earth flat?"
inputs = tokenizer(prompt, return_tensors="pt")

# The intervention is applied during generation
outputs = intervenable.generate(
    inputs,
    max_new_tokens=100,
    do_sample=False,
)
print(tokenizer.decode(outputs[0]))
```
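Under the hood, inference-time steering of this kind adds a learned direction, scaled by a strength coefficient, to selected activations at each generation step (as in Li et al. 2023). A toy sketch; the direction and coefficient below are made up:

```python
# Hypothetical activation and learned "honesty" direction for one head
activation = [0.2, -0.4, 1.0]
honesty_direction = [0.5, 0.5, 0.0]
alpha = 2.0  # steering strength

# Steering = scaled vector addition at inference time
steered = [a + alpha * d for a, d in zip(activation, honesty_direction)]
print(steered)
```

Stronger `alpha` pushes generations further along the direction, at the cost of fluency if overdone.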
Saving and Sharing Interventions
```python
# Save locally
intervenable.save("./my_intervention")

# Load from local
intervenable = pv.IntervenableModel.load(
    "./my_intervention",
    model=model,
)

# Share on HuggingFace
intervenable.save_intervention("username/my-intervention")

# Load from HuggingFace
intervenable = pv.IntervenableModel.load(
    "username/my-intervention",
    model=model,
)
```
Common Issues &amp; Solutions
Issue: Wrong intervention location
```python
# WRONG: incorrect component name
config = pv.RepresentationConfig(
    component="mlp",  # Not valid!
)

# RIGHT: use the exact component name
config = pv.RepresentationConfig(
    component="mlp_output",  # Valid
)
```
Issue: Dimension mismatch
```python
# Ensure source and base have compatible shapes.
# For position-specific interventions:
config = pv.RepresentationConfig(
    unit="pos",
    max_number_of_units=1,  # Intervene on a single position
)

# Specify locations explicitly
intervenable(
    base=base_tokens,
    sources=[source_tokens],
    unit_locations={"sources->base": ([[[5]]], [[[5]]])},  # Position 5
)
```
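The triple-nested lists are easy to get wrong by hand, so it can help to build them with a small helper. This is a sketch under the assumption of one intervention and batch size one; check the nesting convention against the pyvene docs for multi-intervention or batched cases:

```python
def make_unit_locations(source_positions, base_positions):
    """Pair source and base token-position lists in pyvene's nested-list shape
    (assumes one intervention and a batch of one; innermost lists are positions)."""
    return {"sources->base": ([[source_positions]], [[base_positions]])}

locs = make_unit_locations([5], [5])
print(locs["sources->base"])  # → ([[[5]]], [[[5]]])
```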
Issue: Memory with large models
```python
# Use gradient checkpointing
model.gradient_checkpointing_enable()

# Or intervene on fewer components
config = pv.IntervenableConfig(
    representations=[
        pv.RepresentationConfig(
            layer=8,  # A single layer instead of all
            component="block_output",
        )
    ]
)
```
Issue: LoRA integration
```python
# pyvene v0.1.8+ supports LoRAs as interventions
config = pv.RepresentationConfig(
    intervention_type=pv.LoRAIntervention,
    low_rank_dimension=16,
)
```
Key Classes Reference
| Class | Purpose |
|---|---|
| `IntervenableModel` | Main wrapper for interventions |
| `IntervenableConfig` | Configuration container |
| `RepresentationConfig` | Single intervention specification |
| `VanillaIntervention` | Activation swapping |
| `RotatedSpaceIntervention` | Trainable DAS intervention |
| `CollectIntervention` | Activation collection |
Supported Models
pyvene works with any PyTorch model. Tested on:
- GPT-2 (all sizes)
- LLaMA / LLaMA-2
- Pythia
- Mistral / Mixtral
- OPT
- BLIP (vision-language)
- ESM (protein models)
- Mamba (state space)
Reference Documentation
For detailed API documentation, tutorials, and advanced usage, see the references/ folder:

| File | Contents |
|---|---|
| references/README.md | Overview and quick start guide |
| references/api.md | Complete API reference for IntervenableModel, intervention types, configurations |
| references/tutorials.md | Step-by-step tutorials for causal tracing, activation patching, DAS |
External Resources

Tutorials
—

Papers
- Locating and Editing Factual Associations in GPT - Meng et al. (2022)
- Inference-Time Intervention - Li et al. (2023)
- Interpretability in the Wild - Wang et al. (2022)

Official Documentation
—

Comparison with Other Tools

| Feature | pyvene | TransformerLens | nnsight |
|---|---|---|---|
| Declarative config | Yes | No | No |
| HuggingFace sharing | Yes | No | No |
| Trainable interventions | Yes | Limited | Yes |
| Any PyTorch model | Yes | Transformers only | Yes |
| Remote execution | No | No | Yes (NDIF) |