pytorch-model-recovery
PyTorch Model Recovery
This skill provides guidance for tasks involving PyTorch model architecture recovery from state dictionaries, selective layer training, and TorchScript export.
When to Use This Skill
This skill applies when:
- Reconstructing a model architecture from a state dictionary (a `.pt` or `.pth` file containing weights)
- Training or fine-tuning specific layers while keeping others frozen
- Converting a recovered model to TorchScript format
- Debugging model loading issues or architecture mismatches
Approach Overview
Model recovery tasks require a systematic, incremental approach with verification at each step. The key phases are:
- Architecture Analysis - Infer model structure from state dictionary keys
- Architecture Implementation - Build the model class to match the state dict
- Verification - Confirm weights load correctly before any training
- Training - Fine-tune specific layers with appropriate hyperparameters
- Export - Save to required format (often TorchScript)
Phase 1: Architecture Analysis
Examining the State Dictionary
To understand the model architecture, first load and inspect the state dictionary:
```python
import torch

weights = torch.load('model_weights.pt', map_location='cpu')

# Print all keys with shapes
for key, value in weights.items():
    print(f"{key}: {value.shape}")
```
Key Patterns to Identify
Common patterns in state dictionary keys:
| Key Pattern | Indicates |
|---|---|
| `encoder.layers.N.*` | Transformer encoder with N+1 layers |
| `decoder.layers.N.*` | Transformer decoder with N+1 layers |
| `embedding.weight` | Embedding layer |
| `pos_encoder.*` | Positional encoding (often a buffer) |
| `output_layer.weight` | Final linear projection |
| `in_proj_weight` | Combined QKV projection in attention |
| `self_attn.*` | Self-attention component |
| `linear1` / `linear2` | Feed-forward network layers |
| `norm1` / `norm2` | Layer normalization |
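These patterns can also be surfaced programmatically by counting keys per top-level module. A minimal sketch; the `summarize_state_dict` helper and the toy model are illustrative, not part of any real checkpoint:

```python
import collections

import torch.nn as nn

# Hypothetical helper: count state-dict entries per top-level module name
# to get a quick structural overview of an unknown checkpoint.
def summarize_state_dict(state_dict, depth=1):
    groups = collections.Counter()
    for key in state_dict:
        prefix = '.'.join(key.split('.')[:depth])
        groups[prefix] += 1
    return dict(groups)

# Toy model standing in for a loaded checkpoint
model = nn.Sequential(collections.OrderedDict([
    ('embedding', nn.Embedding(10, 4)),
    ('output_layer', nn.Linear(4, 10)),
]))
summary = summarize_state_dict(model.state_dict())
print(summary)  # {'embedding': 1, 'output_layer': 2}
```

Raising `depth` drills into submodules (e.g. `depth=2` distinguishes `encoder.layers` from `encoder.norm`).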
Inferring Dimensions
Extract model dimensions from weight shapes:
```python
# Example: inferring transformer dimensions from weight shapes
d_model = weights['encoder.layers.0.self_attn.in_proj_weight'].shape[1]
# Note: in_proj_weight has shape [3*d_model, d_model] for combined QKV, so
# the head count cannot be read off the shapes alone; nhead must divide
# d_model evenly (head_dim = d_model // nhead, with head_dim = 64 a common
# convention)
vocab_size = weights['embedding.weight'].shape[0]
num_layers = max(int(k.split('.')[2]) for k in weights if 'encoder.layers' in k) + 1
```
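This inference logic can be sanity-checked against a model whose hyperparameters are known; `TinyModel` below is a stand-in for that purpose, not the checkpoint's real architecture:

```python
import torch.nn as nn

# Stand-in model with known hyperparameters
class TinyModel(nn.Module):
    def __init__(self, vocab_size=50, d_model=16, nhead=2, num_layers=3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=32, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)

weights = TinyModel().state_dict()

# Same inference expressions as above
d_model = weights['encoder.layers.0.self_attn.in_proj_weight'].shape[1]
vocab_size = weights['embedding.weight'].shape[0]
num_layers = max(int(k.split('.')[2]) for k in weights if 'encoder.layers' in k) + 1
print(d_model, vocab_size, num_layers)  # 16 50 3
```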
Phase 2: Architecture Implementation
Building the Model Class
When implementing the model class:
- Match the exact layer names used in the state dictionary
- Use the same PyTorch module types (e.g., `nn.TransformerEncoder` vs. custom)
- Register buffers for non-learnable tensors (e.g., positional encodings)
```python
import torch.nn as nn

class RecoveredModel(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, num_layers, dim_feedforward):
        super().__init__()
        # Ensure attribute names match state dict keys exactly
        self.embedding = nn.Embedding(vocab_size, d_model)
        # For positional encoding stored as a buffer
        self.pos_encoder = PositionalEncoding(d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            batch_first=True  # Check whether the original used batch_first
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.output_layer = nn.Linear(d_model, vocab_size)
```
Common Architecture Mistakes
- Incorrect layer naming: `self.fc` vs. `self.output_layer`; attribute names must match the state dict exactly
- Missing buffers: positional encodings are often registered as buffers, not parameters
- Wrong module types: custom attention vs. `nn.MultiheadAttention`
- Batch dimension mismatch: `batch_first=True` vs. `batch_first=False`
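The buffer pitfall above matters because the model class earlier calls a `PositionalEncoding` module without defining it. A minimal sketch of the standard sinusoidal version, with the encoding table registered as a buffer so it shows up in the state dict but never receives gradients:

```python
import math

import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float)
                             * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Buffer: saved in the state dict, but not a trainable parameter
        self.register_buffer('pe', pe.unsqueeze(0))  # [1, max_len, d_model]

    def forward(self, x):  # x: [batch, seq, d_model]
        return x + self.pe[:, :x.size(1)]

enc = PositionalEncoding(d_model=8)
out = enc(torch.zeros(2, 5, 8))
print('pe' in enc.state_dict(), list(enc.parameters()))  # True []
```

If the checkpoint's buffer has a different name (e.g. `pos_encoder.pe` vs. something else), rename the buffer to match.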
Phase 3: Verification (Critical)
Verify Architecture Before Training
Always verify the model loads weights correctly before any training:
```python
model = RecoveredModel(...)

# This will raise an error if keys don't match
model.load_state_dict(weights, strict=True)
print("Weights loaded successfully!")

# Verify a forward pass works
with torch.no_grad():
    dummy_input = torch.randint(0, vocab_size, (1, 10))
    output = model(dummy_input)
    print(f"Output shape: {output.shape}")
```
Handling Key Mismatches
If `load_state_dict` fails, compare the key sets:

```python
model_keys = set(model.state_dict().keys())
weight_keys = set(weights.keys())
missing = weight_keys - model_keys      # in the weights but not the model
unexpected = model_keys - weight_keys   # in the model but not the weights
print(f"In weights but missing from model: {missing}")
print(f"In model but missing from weights: {unexpected}")
```
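When the diff shows a pure naming mismatch, the checkpoint keys can often be remapped instead of renaming the model's attributes. A sketch with hypothetical class names; here `OldNet` saved its head as `fc` while the recovered `NewNet` calls it `output_layer`:

```python
import torch.nn as nn

class OldNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

class NewNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.output_layer = nn.Linear(4, 2)

checkpoint = OldNet().state_dict()  # keys: fc.weight, fc.bias
remapped = {k.replace('fc.', 'output_layer.', 1): v
            for k, v in checkpoint.items()}
NewNet().load_state_dict(remapped, strict=True)
print(sorted(remapped))  # ['output_layer.bias', 'output_layer.weight']
```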
Verify TorchScript Compatibility Early
If TorchScript export is required, test it early:
```python
# Test that scripting works before investing time in training
try:
    scripted = torch.jit.script(model)
    print("TorchScript scripting successful")
except Exception as e:
    print(f"Scripting failed: {e}")
    # Fall back to tracing
    traced = torch.jit.trace(model, dummy_input)
    print("TorchScript tracing successful")
```
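Beyond merely compiling, a scripted module should agree with eager mode numerically. A quick equivalence check; the tiny `nn.Sequential` here is just a placeholder for the recovered model:

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2)).eval()
scripted = torch.jit.script(model)

x = torch.randn(2, 4)
with torch.no_grad():
    same = torch.allclose(model(x), scripted(x))
print(same)  # True
```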
Phase 4: Training Specific Layers
Freezing Layers
To train only specific layers, freeze all others:
```python
# Freeze all parameters first
for param in model.parameters():
    param.requires_grad = False

# Unfreeze only the target layers
for param in model.output_layer.parameters():
    param.requires_grad = True

# Verify the freeze status
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable: {trainable:,} / {total:,} parameters")
```
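Freezing by attribute works for a single head; when many submodules must stay frozen, selecting by parameter name scales better. A sketch on a toy `nn.Sequential`, where index `2` is its last `Linear`:

```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
for name, param in model.named_parameters():
    param.requires_grad = name.startswith('2.')  # train only the last layer

trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print(trainable)  # ['2.weight', '2.bias']
```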
Computing Baseline Loss
Before training, establish a baseline:
```python
model.eval()
with torch.no_grad():
    outputs = model(inputs)
    original_loss = criterion(outputs, targets)
print(f"Original MSE loss: {original_loss.item()}")
```

Training Loop Considerations
```python
# Create the optimizer over only the trainable parameters
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=0.001
)

# Training with progress tracking
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    if epoch % 10 == 0:
        print(f"Epoch {epoch}: Loss = {loss.item():.6f}")
```

Alternative: Closed-Form Solution for Linear Layers
When retraining only a linear output layer, consider a closed-form solution for efficiency:
```python
# Pre-compute the frozen layers' outputs
model.eval()
with torch.no_grad():
    # Get the features fed into the output layer
    features = model.get_features(inputs)  # Shape: [N, d_model]

# torch.linalg.lstsq solves features @ X = targets in the least-squares
# sense, i.e. X = pinv(features) @ targets; the layer weight is W = X.T
solution = torch.linalg.lstsq(features, targets).solution
model.output_layer.weight.data = solution.T
```
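The least-squares step can be validated on synthetic data where the true weight matrix is known in advance:

```python
import torch

torch.manual_seed(0)
W_true = torch.randn(3, 5)        # [out_dim, d_model]
features = torch.randn(100, 5)    # [N, d_model]
targets = features @ W_true.T     # [N, out_dim], noise-free by construction

# lstsq solves features @ X = targets; the layer weight is W = X.T
solution = torch.linalg.lstsq(features, targets).solution
W_est = solution.T
print(torch.allclose(W_est, W_true, atol=1e-4))  # True
```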
Phase 5: TorchScript Export
Saving the Model
```python
# Ensure the model is in eval mode
model.eval()

# Script the model (preferred when there is control flow)
scripted_model = torch.jit.script(model)
scripted_model.save('/app/model.pt')

# Or trace the model (for simpler models)
traced_model = torch.jit.trace(model, example_input)
traced_model.save('/app/model.pt')
```
Verify Saved Model
```python
# Reload and verify
loaded = torch.jit.load('/app/model.pt')
loaded.eval()
with torch.no_grad():
    original_out = model(test_input)
    loaded_out = loaded(test_input)
diff = (original_out - loaded_out).abs().max()
print(f"Max difference: {diff.item()}")
assert diff < 1e-5, "Model outputs don't match!"
```
Environment Considerations
Handling Slow Environments
When operating in resource-constrained environments:
- Benchmark first: test basic operations before committing to a full solution:

  ```python
  import time

  start = time.time()
  _ = model(torch.randint(0, vocab_size, (1, 10)))
  print(f"Single forward pass: {time.time() - start:.2f}s")
  ```

- Reduce batch size: process samples individually if needed
- Set realistic timeouts: base them on benchmarks, not arbitrary values
- Use incremental checkpoints: save progress periodically
Memory Management
```python
# Clear the GPU cache between operations
torch.cuda.empty_cache()

# Use gradient checkpointing for large models
from torch.utils.checkpoint import checkpoint

# Process in smaller batches
for batch in torch.split(data, batch_size):
    process(batch)
```
Common Pitfalls
- Not verifying the architecture match before training: always test `load_state_dict` first
- Arbitrary hyperparameters: justify choices based on task characteristics
- Ignoring TorchScript compatibility: test export early, not after training
- Syntax errors in edits: review code changes carefully, especially string formatting
- Incomplete state dict mapping: verify all keys are accounted for
- Not establishing baseline metrics: compute the original loss before training
- Missing `torch.no_grad()` for inference: use the context manager for evaluation
- Forgetting to set `model.eval()`: required for consistent behavior in eval/export
Verification Checklist
Before considering the task complete:
- State dictionary keys fully analyzed and documented
- Model architecture matches state dict exactly (verified with `load_state_dict`)
- Forward pass produces valid output
- Baseline loss/metric computed
- Target layers correctly unfrozen, others frozen
- Training improves loss over baseline
- TorchScript export succeeds
- Exported model produces same outputs as original
- Model saved to required path