# PyTorch Model Recovery
Guidance for recovering PyTorch model architectures from state dictionaries, retraining specific layers, and saving models in TorchScript format. This skill should be used when tasks involve reconstructing model architectures from saved weights, fine-tuning specific layers while freezing others, or converting models to TorchScript format.
npx skill4agent add letta-ai/skills pytorch-model-recovery.pt.pthimport torch
weights = torch.load('model_weights.pt', map_location='cpu')
# Print all keys with shapes
for key, value in weights.items():
print(f"{key}: {value.shape}")| Key Pattern | Indicates |
|---|---|
| `encoder.layers.{N}.*` | Transformer encoder with N+1 layers |
| `decoder.layers.{N}.*` | Transformer decoder with N+1 layers |
| `embedding.weight` | Embedding layer |
| `pos_encoder.pe` | Positional encoding (often a buffer) |
| `output_layer.weight` / `fc.weight` | Final linear projection |
| `self_attn.in_proj_weight` | Combined QKV projection in attention |
| `self_attn.*` | Self-attention component |
| `linear1.*`, `linear2.*` | Feed-forward network layers |
| `norm1.*`, `norm2.*` | Layer normalization |
# Example: Inferring transformer dimensions from weight shapes.
# Note: in_proj_weight has shape [3*d_model, d_model] for combined QKV,
# so d_model is the second dimension.
d_model = weights['encoder.layers.0.self_attn.in_proj_weight'].shape[1]
# nhead is NOT recoverable from weight shapes alone: every head count that
# divides d_model yields identical parameter shapes. Try common divisors and
# validate each candidate with a forward pass / loss check on known data.
nhead = next(h for h in (8, 16, 4, 2, 1) if d_model % h == 0)  # heuristic guess
vocab_size = weights['embedding.weight'].shape[0]
# Layer indices are zero-based, so the count is the highest index + 1.
num_layers = max(int(k.split('.')[2]) for k in weights if 'encoder.layers' in k) + 1


class RecoveredModel(nn.Module):
    """Transformer encoder whose attribute names mirror the recovered state dict.

    Attribute names must match the state-dict key prefixes exactly
    (``embedding``, ``pos_encoder``, ``encoder``, ``output_layer``) for
    ``load_state_dict(strict=True)`` to succeed. If the checkpoint uses
    ``fc.*`` keys, name the head ``self.fc`` instead of ``self.output_layer``.
    Note that ``nn.MultiheadAttention`` weight layout is the same for
    ``batch_first=True`` and ``batch_first=False`` — only runtime input
    shapes differ, so batch_first cannot be inferred from the weights.
    """

    def __init__(self, vocab_size, d_model, nhead, num_layers, dim_feedforward):
        super().__init__()
        # Ensure attribute names match state-dict keys exactly.
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Positional encoding is often stored as a buffer; PositionalEncoding
        # is defined elsewhere in the project.
        self.pos_encoder = PositionalEncoding(d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            batch_first=True,  # check whether the original used batch_first
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.output_layer = nn.Linear(d_model, vocab_size)


model = RecoveredModel(...)  # fill in the dimensions inferred above
# This will raise an error if keys don't match
model.load_state_dict(weights, strict=True)
print("Weights loaded successfully!")

# Verify a forward pass works
with torch.no_grad():
    dummy_input = torch.randint(0, vocab_size, (1, 10))
    output = model(dummy_input)
    print(f"Output shape: {output.shape}")

# If load_state_dict fails, diff the key sets to see why:
model_keys = set(model.state_dict().keys())
weight_keys = set(weights.keys())
# load_state_dict semantics: "missing" keys are the ones the model expects
# but the checkpoint lacks; "unexpected" keys are present in the checkpoint
# but unknown to the model. (The original snippet had these swapped.)
missing = model_keys - weight_keys
unexpected = weight_keys - model_keys
print(f"Missing from checkpoint: {missing}")
print(f"Unexpected in checkpoint: {unexpected}")

# Test scripting works before investing time in training
# Confirm the model is scriptable before spending time on training.
try:
    scripted = torch.jit.script(model)
    print("TorchScript scripting successful")
except Exception as err:
    print(f"Scripting failed: {err}")
    # Fall back to tracing, which records one concrete execution path.
    traced = torch.jit.trace(model, dummy_input)
    print("TorchScript tracing successful")

# Freeze all parameters first
# Freeze everything, then selectively re-enable the target layer.
for p in model.parameters():
    p.requires_grad = False
for p in model.output_layer.parameters():
    p.requires_grad = True

# Report how much of the network remains trainable.
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable: {trainable:,} / {total:,} parameters")

# Switch to eval mode before the baseline measurement.
model.eval()
# Record the frozen model's loss before any retraining, as a baseline.
with torch.no_grad():
    outputs = model(inputs)
    original_loss = criterion(outputs, targets)
    print(f"Original MSE loss: {original_loss.item()}")

# Create optimizer only for trainable parameters
# The optimizer should only see parameters left unfrozen above.
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad),
    lr=0.001,
)

# Training loop with periodic progress reporting.
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    if epoch % 10 == 0:
        print(f"Epoch {epoch}: Loss = {loss.item():.6f}")

# Pre-compute frozen layer outputs
# Closed-form alternative when only a linear head is retrained:
# pre-compute the frozen features once, then fit the head analytically.
model.eval()
with torch.no_grad():
    # Features feeding the output layer; shape [N, d_model].
    features = model.get_features(inputs)
    # Solve features @ W_T ~= targets in the least-squares sense:
    #   W_T = features^+ @ targets = (features^T features)^-1 features^T targets
    # (the original comment's pseudo-inverse formula was transposed/wrong).
    solution = torch.linalg.lstsq(features, targets).solution
    # copy_ under no_grad is the supported way to overwrite a parameter;
    # direct ``.data =`` assignment bypasses autograd bookkeeping.
    # NOTE: this fits only the weight — the bias is left untouched; append a
    # ones-column to ``features`` if the bias should be fitted too.
    model.output_layer.weight.copy_(solution.T)
# Ensure model is in eval mode
# Export to TorchScript; the model must be in eval mode first.
model.eval()

# Option A — scripting (preferred when the forward pass has control flow).
scripted_model = torch.jit.script(model)
scripted_model.save('/app/model.pt')

# Option B — tracing (sufficient for straight-line models).
traced_model = torch.jit.trace(model, example_input)
traced_model.save('/app/model.pt')
# Reload and verify
# Round-trip check: the reloaded TorchScript model must match the original.
loaded = torch.jit.load('/app/model.pt')
loaded.eval()
with torch.no_grad():
    original_out = model(test_input)
    loaded_out = loaded(test_input)
diff = (original_out - loaded_out).abs().max()
print(f"Max difference: {diff.item()}")
assert diff < 1e-5, "Model outputs don't match!"

import time
# Sanity-check the latency of a single forward pass.
start = time.time()
_ = model(torch.randint(0, vocab_size, (1, 10)))
print(f"Single forward pass: {time.time() - start:.2f}s")
# Clear GPU cache between operations
# Free cached GPU memory between large operations.
torch.cuda.empty_cache()

# Use gradient checkpointing for large models (trades compute for memory).
from torch.utils.checkpoint import checkpoint

# Process in smaller batches to bound peak memory.
for batch in torch.split(data, batch_size):
    process(batch)

# Key APIs recap: load_state_dict, torch.no_grad(), model.eval().