Use this skill when you have implemented an equivariant model and need to verify that it respects the intended symmetries. Invoke it when the user mentions testing model equivariance, debugging symmetry bugs, verifying implementation correctness, checking whether a model is actually equivariant, or diagnosing why an equivariant model isn't working. Provides verification tests and debugging guidance.
```bash
npx skill4agent add lyndonkl/claude model-equivariance-auditor
```

Equivariance Audit Progress:
- [ ] Step 1: Gather model and symmetry specification
- [ ] Step 2: Run numerical equivariance tests
- [ ] Step 3: Test individual layers
- [ ] Step 4: Check gradient equivariance
- [ ] Step 5: Identify and diagnose failures
- [ ] Step 6: Document audit results

```python
import numpy as np
import torch


def test_model_equivariance(model, x, input_transform, output_transform,
                            n_tests=100, tol=1e-5):
    """
    Test whether the model is equivariant: f(T(x)) ≈ T'(f(x)).

    Args:
        model: The neural network to test
        x: Sample input tensor
        input_transform: Function (x, T) -> transformed input
        output_transform: Function (y, T) -> transformed output
        n_tests: Number of random transformations to test
        tol: Relative error tolerance

    Returns:
        dict with aggregate test results
    """
    model.eval()
    errors = []
    with torch.no_grad():
        for _ in range(n_tests):
            # Sample a random group element (sampler is user-supplied)
            T = sample_random_transform()
            # Method 1: transform input, then apply model
            y1 = model(input_transform(x, T))
            # Method 2: apply model, then transform output
            y2 = output_transform(model(x), T)
            # Compare the two paths
            error = torch.norm(y1 - y2).item()
            relative_error = error / (torch.norm(y2).item() + 1e-8)
            errors.append({'absolute': error, 'relative': relative_error})
    return {
        'mean_absolute': np.mean([e['absolute'] for e in errors]),
        'max_absolute': np.max([e['absolute'] for e in errors]),
        'mean_relative': np.mean([e['relative'] for e in errors]),
        'max_relative': np.max([e['relative'] for e in errors]),
        'pass': all(e['relative'] < tol for e in errors),
    }


def test_model_invariance(model, x, transform, n_tests=100, tol=1e-5):
    """Test whether the model output is invariant to transformations."""
    model.eval()
    errors = []
    with torch.no_grad():
        y_original = model(x)
        for _ in range(n_tests):
            T = sample_random_transform()
            x_transformed = transform(x, T)
            y_transformed = model(x_transformed)
            errors.append(torch.norm(y_transformed - y_original).item())
    return {
        'mean_error': np.mean(errors),
        'max_error': np.max(errors),
        'pass': max(errors) < tol,
    }


def test_layer_equivariance(layer, x, input_transform, output_transform,
                            tol=1e-5):
    """Test a single layer for equivariance."""
    layer.eval()
    with torch.no_grad():
        T = sample_random_transform()
        # Transform then layer
        y1 = layer(input_transform(x, T))
        # Layer then transform
        y2 = output_transform(layer(x), T)
        error = torch.norm(y1 - y2).item()
    return {
        'layer': layer.__class__.__name__,
        'error': error,
        'pass': error < tol,
    }


def audit_all_layers(model, x, transforms):
    """Test each layer individually.

    Note: x and the transform pair must match each layer's expected
    input (e.g. intermediate activations for inner layers), and
    is_testable_layer is a user-supplied predicate selecting which
    modules to check.
    """
    results = []
    for name, layer in model.named_modules():
        if is_testable_layer(layer):
            result = test_layer_equivariance(layer, x, *transforms)
            result['name'] = name
            results.append(result)
    return results
```

| Layer Type | What to Check |
|---|---|
| Convolution | Kernel equivariance |
| Nonlinearity | Should preserve equivariance |
| Normalization | Often breaks equivariance |
| Pooling | Correct aggregation |
| Linear | Weight sharing patterns |
| Attention | Permutation equivariance |
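As a concrete instance of the two-path check above, here is a minimal NumPy sketch assuming 2D point-cloud inputs and the rotation group SO(2); the centroid-subtraction "model" and the helper names (`sample_rotation_2d`, `rotate`, `center`) are illustrative, not part of the skill:

```python
import numpy as np

def sample_rotation_2d(rng):
    """Sample a random 2D rotation matrix (an SO(2) group element)."""
    theta = rng.uniform(0, 2 * np.pi)
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s], [s, c]])

def rotate(x, R):
    """Apply the rotation R to each row (point) of x."""
    return x @ R.T

def center(x):
    """Toy equivariant 'model': subtract the centroid of the point cloud."""
    return x - x.mean(axis=0, keepdims=True)

rng = np.random.default_rng(0)
x = rng.normal(size=(16, 2))
R = sample_rotation_2d(rng)

# Path 1: transform the input, then apply the model
y1 = center(rotate(x, R))
# Path 2: apply the model, then transform the output
y2 = rotate(center(x), R)
err = np.linalg.norm(y1 - y2)
# err sits at floating-point precision, so the check passes
```

Centroid subtraction commutes with rotation, so the two paths agree up to round-off; to plug a sampler like `sample_rotation_2d` into the torch harness above, bind it as `sample_random_transform`.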
```python
def test_gradient_equivariance(model, x, loss_fn, transform, tol=1e-4):
    """Test whether input gradients respect the symmetry."""
    model.train()
    # Gradients at the original input
    x1 = x.clone().requires_grad_(True)
    y1 = model(x1)
    loss1 = loss_fn(y1)
    loss1.backward()
    grad1 = x1.grad.clone()
    # Gradients at a transformed input
    model.zero_grad()
    T = sample_random_transform()
    x2 = transform(x.clone(), T).requires_grad_(True)
    y2 = model(x2)
    loss2 = loss_fn(y2)
    loss2.backward()
    grad2 = x2.grad.clone()
    # Transform grad1 and compare to grad2. transform_gradient is
    # user-supplied and depends on how the group acts on the input space
    # (for orthogonal transforms, gradients transform like inputs).
    grad1_transformed = transform_gradient(grad1, T)
    error = torch.norm(grad2 - grad1_transformed).item()
    return {'error': error, 'pass': error < tol}
```

| Error Level | Interpretation | Action |
|---|---|---|
| < 1e-6 | Perfect (float32 precision) | Pass |
| 1e-6 to 1e-4 | Excellent (acceptable) | Pass |
| 1e-4 to 1e-2 | Approximate equivariance | Investigate |
| > 1e-2 | Broken equivariance | Fix required |
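To make the `transform_gradient` step concrete: for an orthogonal transformation and an invariant loss, the chain rule implies that input gradients transform exactly like inputs, i.e. ∇L(Rx) = R ∇L(x). A NumPy sketch with an illustrative rotation-invariant squared-norm loss (the names `loss` and `grad_loss` are assumptions, not part of the skill):

```python
import numpy as np

def loss(x):
    """Rotation-invariant loss: squared norm of all points."""
    return (x ** 2).sum()

def grad_loss(x):
    """Analytic gradient of the loss above: d/dx sum(x^2) = 2x."""
    return 2 * x

rng = np.random.default_rng(1)
theta = rng.uniform(0, 2 * np.pi)
R = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta),  np.cos(theta)]])
x = rng.normal(size=(8, 2))

g1 = grad_loss(x)          # gradient at the original input
g2 = grad_loss(x @ R.T)    # gradient at the rotated input
# The gradient at the rotated input equals the rotated gradient,
# so transform_gradient here is just the input transform itself
err = np.linalg.norm(g2 - g1 @ R.T)
```

For non-orthogonal group actions the gradient picks up the inverse-transpose of the transformation instead, which is why `transform_gradient` must be written per symmetry group.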
```
MODEL EQUIVARIANCE AUDIT REPORT
===============================

Model: [Model name/description]
Intended Symmetry: [Group]
Symmetry Type: [Invariant/Equivariant]

END-TO-END TESTS:
-----------------
Test samples: [N]
Transformations tested: [M]
Invariance/Equivariance Error:
  - Mean absolute: [value]
  - Max absolute: [value]
  - Mean relative: [value]
  - Max relative: [value]
  - RESULT: [PASS/FAIL]

LAYER-WISE ANALYSIS:
--------------------
[For each layer]
  - Layer: [name]
  - Error: [value]
  - Result: [PASS/FAIL]

GRADIENT TEST:
--------------
- Gradient equivariance error: [value]
- RESULT: [PASS/FAIL]

IDENTIFIED ISSUES:
------------------
1. [Issue description]
   - Location: [layer/component]
   - Severity: [High/Medium/Low]
   - Recommended fix: [description]

OVERALL VERDICT: [PASS/FAIL/NEEDS_ATTENTION]

Recommendations:
- [List of actions needed]
```
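The three verdicts in the template can be derived mechanically from the `pass` flags returned by the test functions; the following sketch is one possible rollup (the `overall_verdict` helper is illustrative, not part of the skill):

```python
def overall_verdict(e2e, layer_results, grad_result):
    """Roll individual test results up into the report's verdict line."""
    layers_ok = all(r['pass'] for r in layer_results)
    if e2e['pass'] and layers_ok and grad_result['pass']:
        return 'PASS'
    if not e2e['pass']:
        # Broken end-to-end symmetry is a hard failure
        return 'FAIL'
    # End-to-end passes but a layer or the gradient test does not
    return 'NEEDS_ATTENTION'

# Hypothetical results from the tests above
verdict = overall_verdict(
    e2e={'pass': True},
    layer_results=[{'name': 'conv1', 'pass': True},
                   {'name': 'norm1', 'pass': False}],
    grad_result={'pass': True},
)
# verdict == 'NEEDS_ATTENTION': the norm layer needs investigation
```

Mapping failures this way keeps the report consistent with the error-level table: a single layer above tolerance downgrades the audit without masking an otherwise-correct end-to-end result.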