pyvene-interventions

pyvene: Causal Interventions for Neural Networks


pyvene is Stanford NLP's library for performing causal interventions on PyTorch models. It provides a declarative, dict-based framework for activation patching, causal tracing, and interchange intervention training - making intervention experiments reproducible and shareable.

When to Use pyvene


Use pyvene when you need to:
  • Perform causal tracing (ROME-style localization)
  • Run activation patching experiments
  • Conduct interchange intervention training (IIT)
  • Test causal hypotheses about model components
  • Share/reproduce intervention experiments via HuggingFace
  • Work with any PyTorch architecture (not just transformers)
Consider alternatives when:
  • You need exploratory activation analysis → Use TransformerLens
  • You want to train/analyze SAEs → Use SAELens
  • You need remote execution on massive models → Use nnsight
  • You want lower-level control → Use nnsight

Installation


```bash
pip install pyvene
```

Standard import:

```python
import pyvene as pv
```

Core Concepts


IntervenableModel


The main class that wraps any PyTorch model with intervention capabilities:
```python
import pyvene as pv
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load base model
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Define intervention configuration
config = pv.IntervenableConfig(
    representations=[
        pv.RepresentationConfig(
            layer=8,
            component="block_output",
            intervention_type=pv.VanillaIntervention,
        )
    ]
)

# Create intervenable model
intervenable = pv.IntervenableModel(config, model)
```

Intervention Types


| Type | Description | Use Case |
|---|---|---|
| `VanillaIntervention` | Swap activations between runs | Activation patching |
| `AdditionIntervention` | Add activations to the base run | Steering, ablation |
| `SubtractionIntervention` | Subtract activations | Ablation |
| `ZeroIntervention` | Zero out activations | Component knockout |
| `RotatedSpaceIntervention` | Trainable DAS intervention | Causal discovery |
| `CollectIntervention` | Collect activations | Probing, analysis |
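As a rough guide to the semantics, the basic intervention types reduce to simple arithmetic on activations. The sketch below illustrates this on plain Python lists; it is a conceptual model, not pyvene's implementation, which applies these operations to batched tensors via forward hooks:

```python
# Illustrative semantics of pyvene's basic intervention types,
# sketched on plain Python lists (one activation vector per run).

def vanilla(base, source):
    """VanillaIntervention: replace base activations with source activations."""
    return list(source)

def addition(base, source):
    """AdditionIntervention: add source activations to the base run."""
    return [b + s for b, s in zip(base, source)]

def subtraction(base, source):
    """SubtractionIntervention: subtract source activations from the base run."""
    return [b - s for b, s in zip(base, source)]

def zero(base):
    """ZeroIntervention: knock the component out entirely."""
    return [0.0 for _ in base]

base, source = [1.0, 2.0, 3.0], [10.0, 20.0, 30.0]
print(vanilla(base, source))      # [10.0, 20.0, 30.0]
print(addition(base, source))     # [11.0, 22.0, 33.0]
print(subtraction(base, source))  # [-9.0, -18.0, -27.0]
print(zero(base))                 # [0.0, 0.0, 0.0]
```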

Component Targets


```python
# Available components to intervene on
components = [
    "block_input",                  # Input to transformer block
    "block_output",                 # Output of transformer block
    "mlp_input",                    # Input to MLP
    "mlp_output",                   # Output of MLP
    "mlp_activation",               # MLP hidden activations
    "attention_input",              # Input to attention
    "attention_output",             # Output of attention
    "attention_value_output",       # Attention value vectors
    "query_output",                 # Query vectors
    "key_output",                   # Key vectors
    "value_output",                 # Value vectors
    "head_attention_value_output",  # Per-head values
]
```
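Component names must match these strings exactly (see Common Issues below). A tiny guard like the following, a convenience sketch rather than part of pyvene, catches typos before they surface as confusing hook errors:

```python
# Hypothetical helper: validate a component name before building a config.
VALID_COMPONENTS = {
    "block_input", "block_output", "mlp_input", "mlp_output",
    "mlp_activation", "attention_input", "attention_output",
    "attention_value_output", "query_output", "key_output",
    "value_output", "head_attention_value_output",
}

def check_component(name):
    """Raise early on a typo'd component name instead of failing inside pyvene."""
    if name not in VALID_COMPONENTS:
        raise ValueError(
            f"Unknown component {name!r}; expected one of {sorted(VALID_COMPONENTS)}"
        )
    return name

check_component("mlp_output")   # fine
# check_component("mlp")        # would raise ValueError
```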

Workflow 1: Causal Tracing (ROME-style)


Locate where factual associations are stored by corrupting inputs and restoring activations.

Step-by-Step


```python
import pyvene as pv
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model = AutoModelForCausalLM.from_pretrained("gpt2-xl")
tokenizer = AutoTokenizer.from_pretrained("gpt2-xl")
```

1. Define clean and corrupted inputs


```python
clean_prompt = "The Space Needle is in downtown"
corrupted_prompt = "The ##### ###### ## ## ########"  # Noise

clean_tokens = tokenizer(clean_prompt, return_tensors="pt")
corrupted_tokens = tokenizer(corrupted_prompt, return_tensors="pt")
```

2. Get clean activations (source)


```python
with torch.no_grad():
    clean_outputs = model(**clean_tokens, output_hidden_states=True)
    clean_states = clean_outputs.hidden_states
```

3. Define restoration intervention


```python
def run_causal_trace(layer, position):
    """Restore the clean activation at a specific layer and position."""
    config = pv.IntervenableConfig(
        representations=[
            pv.RepresentationConfig(
                layer=layer,
                component="block_output",
                intervention_type=pv.VanillaIntervention,
                unit="pos",
                max_number_of_units=1,
            )
        ]
    )
    intervenable = pv.IntervenableModel(config, model)

    # Run with intervention
    _, patched_outputs = intervenable(
        base=corrupted_tokens,
        sources=[clean_tokens],
        unit_locations={"sources->base": ([[[position]]], [[[position]]])},
        output_original_output=True,
    )

    # Return probability of the correct token
    probs = torch.softmax(patched_outputs.logits[0, -1], dim=-1)
    seattle_token = tokenizer.encode(" Seattle")[0]
    return probs[seattle_token].item()
```

4. Sweep over layers and positions


```python
n_layers = model.config.n_layer
seq_len = clean_tokens["input_ids"].shape[1]

results = torch.zeros(n_layers, seq_len)
for layer in range(n_layers):
    for pos in range(seq_len):
        results[layer, pos] = run_causal_trace(layer, pos)
```

5. Visualize (layer x position heatmap)


High values indicate causal importance


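A common convention (from the ROME causal-tracing work) is to normalize each patched probability against the clean and corrupted baselines, so 1.0 means full restoration and 0.0 means no effect. Below is a dependency-free sketch of that normalization plus a crude text heatmap; the helper names and the toy `results` grid are illustrative, not pyvene API:

```python
def restoration_effect(patched_prob, clean_prob, corrupted_prob):
    """Normalize a patched probability: 1.0 = fully restored, 0.0 = no effect."""
    denom = clean_prob - corrupted_prob
    if abs(denom) < 1e-12:
        return 0.0
    return (patched_prob - corrupted_prob) / denom

def ascii_heatmap(results, chars=" .:-=+*#%@"):
    """Render a layers x positions grid of values in [0, 1] as text."""
    lines = []
    for layer, row in enumerate(results):
        cells = "".join(
            chars[max(0, min(int(v * (len(chars) - 1)), len(chars) - 1))]
            for v in row
        )
        lines.append(f"L{layer:02d} |{cells}|")
    return "\n".join(lines)

# Toy sweep: a causal hotspot around layer 1, position 2
results = [
    [0.0, 0.1, 0.2, 0.1],
    [0.1, 0.3, 0.9, 0.2],
    [0.0, 0.2, 0.5, 0.1],
]
print(ascii_heatmap(results))
```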

Checklist


  • Prepare clean prompt with target factual association
  • Create corrupted version (noise or counterfactual)
  • Define intervention config for each (layer, position)
  • Run patching sweep
  • Identify causal hotspots in heatmap

Workflow 2: Activation Patching for Circuit Analysis


Test which components are necessary for a specific behavior.

Step-by-Step


```python
import pyvene as pv
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
```

IOI task setup


```python
clean_prompt = "When John and Mary went to the store, Mary gave a bottle to"
corrupted_prompt = "When John and Mary went to the store, John gave a bottle to"

clean_tokens = tokenizer(clean_prompt, return_tensors="pt")
corrupted_tokens = tokenizer(corrupted_prompt, return_tensors="pt")

john_token = tokenizer.encode(" John")[0]
mary_token = tokenizer.encode(" Mary")[0]

def logit_diff(logits):
    """IO - S logit difference."""
    return logits[0, -1, john_token] - logits[0, -1, mary_token]
```

Patch attention output at each layer


```python
def patch_attention(layer):
    config = pv.IntervenableConfig(
        representations=[
            pv.RepresentationConfig(
                layer=layer,
                component="attention_output",
                intervention_type=pv.VanillaIntervention,
            )
        ]
    )
    intervenable = pv.IntervenableModel(config, model)

    _, patched_outputs = intervenable(
        base=corrupted_tokens,
        sources=[clean_tokens],
    )

    return logit_diff(patched_outputs.logits).item()
```

Find which layers matter


```python
results = []
for layer in range(model.config.n_layer):
    diff = patch_attention(layer)
    results.append(diff)
    print(f"Layer {layer}: logit diff = {diff:.3f}")
```

Workflow 3: Interchange Intervention Training (IIT)


Train interventions to discover causal structure.

Step-by-Step


```python
import pyvene as pv
from transformers import AutoModelForCausalLM
import torch

model = AutoModelForCausalLM.from_pretrained("gpt2")
```

1. Define trainable intervention


```python
config = pv.IntervenableConfig(
    representations=[
        pv.RepresentationConfig(
            layer=6,
            component="block_output",
            intervention_type=pv.RotatedSpaceIntervention,  # Trainable
            low_rank_dimension=64,  # Learn a 64-dim subspace
        )
    ]
)
intervenable = pv.IntervenableModel(config, model)
```

2. Set up training


```python
optimizer = torch.optim.Adam(
    intervenable.get_trainable_parameters(),
    lr=1e-4,
)
```

3. Training loop (simplified)


```python
# Assumes a task-specific `dataloader` yielding (base, source, target)
# triples and a `criterion` such as torch.nn.CrossEntropyLoss.
for base_input, source_input, target_output in dataloader:
    optimizer.zero_grad()

    _, outputs = intervenable(
        base=base_input,
        sources=[source_input],
    )

    loss = criterion(outputs.logits, target_output)
    loss.backward()
    optimizer.step()
```

4. Analyze learned intervention


```python
# The learned rotation matrix reveals the causal subspace
rotation = intervenable.interventions["layer.6.block_output"][0].rotate_layer
```

DAS (Distributed Alignment Search)


```python
# Low-rank rotation finds interpretable subspaces
config = pv.IntervenableConfig(
    representations=[
        pv.RepresentationConfig(
            layer=8,
            component="block_output",
            intervention_type=pv.LowRankRotatedSpaceIntervention,
            low_rank_dimension=1,  # Find a 1D causal direction
        )
    ]
)
```
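Conceptually, a rotated-space interchange swaps coordinates in a learned orthonormal basis rather than in the raw neuron basis. A sketch of the standard DAS update (following Geiger et al.), where $R$ is the learned rotation and $m$ is a 0/1 mask selecting the low-rank subspace:

```latex
h_{\text{new}} = R^{\top}\big((1 - m) \odot R\,h_{\text{base}} + m \odot R\,h_{\text{source}}\big)
```

Coordinates where $m = 1$ take their values from the source run; all others keep their base-run values.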

Workflow 4: Model Steering (Honest LLaMA)

工作流4:模型引导(Honest LLaMA)

Steer model behavior during generation.
```python
import pyvene as pv
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
```

Load pre-trained steering intervention


```python
intervenable = pv.IntervenableModel.load(
    "zhengxuanzenwu/intervenable_honest_llama2_chat_7B",
    model=model,
)
```

Generate with steering


```python
prompt = "Is the earth flat?"
inputs = tokenizer(prompt, return_tensors="pt")
```

Intervention applied during generation


```python
outputs = intervenable.generate(
    inputs,
    max_new_tokens=100,
    do_sample=False,
)
print(tokenizer.decode(outputs[0]))
```

Saving and Sharing Interventions



Save locally


```python
intervenable.save("./my_intervention")
```

Load from local


```python
intervenable = pv.IntervenableModel.load(
    "./my_intervention",
    model=model,
)
```

Share on HuggingFace


```python
intervenable.save_intervention("username/my-intervention")
```

Load from HuggingFace


```python
intervenable = pv.IntervenableModel.load(
    "username/my-intervention",
    model=model,
)
```

Common Issues & Solutions


Issue: Wrong component name



WRONG: Incorrect component name


```python
config = pv.RepresentationConfig(
    component="mlp",  # Not valid!
)
```

RIGHT: Use exact component name


```python
config = pv.RepresentationConfig(
    component="mlp_output",  # Valid
)
```

Issue: Dimension mismatch



Ensure source and base have compatible shapes


For position-specific interventions:


```python
config = pv.RepresentationConfig(
    unit="pos",
    max_number_of_units=1,  # Intervene on a single position
)
```

Specify locations explicitly


```python
intervenable(
    base=base_tokens,
    sources=[source_tokens],
    unit_locations={"sources->base": ([[[5]]], [[[5]]])},  # Position 5
)
```
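The triple-nested lists in `unit_locations` encode interventions, examples, and positions (check the pyvene API docs for the exact ordering in your version). For the common one-example, one-intervention, one-position case, a small helper, sketched here and not part of pyvene, makes the intent explicit:

```python
def single_position_locations(source_pos, base_pos=None):
    """Build a pyvene-style unit_locations dict for the 1-example,
    1-intervention case: patch `source_pos` of the source run into
    `base_pos` of the base run (defaults to the same position)."""
    if base_pos is None:
        base_pos = source_pos
    return {"sources->base": ([[[source_pos]]], [[[base_pos]]])}

# Equivalent to the literal {"sources->base": ([[[5]]], [[[5]]])}
print(single_position_locations(5))
```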

Issue: Out-of-memory errors with large models



Use gradient checkpointing


```python
model.gradient_checkpointing_enable()
```

Or intervene on fewer components


```python
config = pv.IntervenableConfig(
    representations=[
        pv.RepresentationConfig(
            layer=8,  # A single layer instead of all layers
            component="block_output",
        )
    ]
)
```

Issue: LoRA integration



pyvene v0.1.8+ supports LoRAs as interventions


```python
config = pv.RepresentationConfig(
    intervention_type=pv.LoRAIntervention,
    low_rank_dimension=16,
)
```

Key Classes Reference


| Class | Purpose |
|---|---|
| `IntervenableModel` | Main wrapper for interventions |
| `IntervenableConfig` | Configuration container |
| `RepresentationConfig` | Single intervention specification |
| `VanillaIntervention` | Activation swapping |
| `RotatedSpaceIntervention` | Trainable DAS intervention |
| `CollectIntervention` | Activation collection |

Supported Models


pyvene works with any PyTorch model. Tested on:
  • GPT-2 (all sizes)
  • LLaMA / LLaMA-2
  • Pythia
  • Mistral / Mixtral
  • OPT
  • BLIP (vision-language)
  • ESM (protein models)
  • Mamba (state space)

Reference Documentation


For detailed API documentation, tutorials, and advanced usage, see the `references/` folder:

| File | Contents |
|---|---|
| `references/README.md` | Overview and quick start guide |
| `references/api.md` | Complete API reference for IntervenableModel, intervention types, and configurations |
| `references/tutorials.md` | Step-by-step tutorials for causal tracing, activation patching, and DAS |

External Resources

Tutorials

Papers

Official Documentation

Comparison with Other Tools

| Feature | pyvene | TransformerLens | nnsight |
|---|---|---|---|
| Declarative config | Yes | No | No |
| HuggingFace sharing | Yes | No | No |
| Trainable interventions | Yes | Limited | Yes |
| Any PyTorch model | Yes | Transformers only | Yes |
| Remote execution | No | No | Yes (NDIF) |