model-equivariance-auditor

Model Equivariance Auditor

What Is It?

This skill helps you verify that your implemented model correctly respects its intended symmetries. Even with equivariant libraries, implementation bugs can break equivariance. This skill provides systematic verification tests and debugging strategies.
Why audit? A model that is supposed to be equivariant but isn't will train poorly and give inconsistent predictions. Catching these bugs early saves debugging time.

Workflow

Copy this checklist and track your progress:
Equivariance Audit Progress:
- [ ] Step 1: Gather model and symmetry specification
- [ ] Step 2: Run numerical equivariance tests
- [ ] Step 3: Test individual layers
- [ ] Step 4: Check gradient equivariance
- [ ] Step 5: Identify and diagnose failures
- [ ] Step 6: Document audit results
Step 1: Gather model and symmetry specification
Collect: the implemented model, the intended symmetry group, whether each output should be invariant or equivariant, and the transformation functions for the input and output spaces. Review the architecture specification from the design phase. Clarify any ambiguities with the user before testing.
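The items gathered in this step can be bundled into a small specification object so the later tests all consume the same inputs. A minimal sketch; the `SymmetrySpec` name and its fields are illustrative, not part of this skill's interface:

```python
from dataclasses import dataclass
from typing import Callable

@dataclass
class SymmetrySpec:
    """Illustrative container for the audit inputs gathered in Step 1."""
    group: str                  # e.g. "SO(2)", "C4", "S_n"
    output_behavior: str        # "invariant" or "equivariant"
    input_transform: Callable   # (x, T) -> transformed input
    output_transform: Callable  # (y, T) -> transformed output

# Example: planar rotations acting on point clouds of shape (N, 2),
# with rotations applied by right-multiplication.
spec = SymmetrySpec(
    group="SO(2)",
    output_behavior="equivariant",
    input_transform=lambda x, T: x @ T.T,
    output_transform=lambda y, T: y @ T.T,
)
```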
Step 2: Run numerical equivariance tests
Execute end-to-end equivariance tests using Test Implementation. For invariance: verify ||f(T(x)) - f(x)|| < ε. For equivariance: verify ||f(T(x)) - T'(f(x))|| < ε. Use multiple random inputs and transformations. Record error statistics. See Error Interpretation for thresholds. For ready-to-use test code, see Test Code Templates.
Step 3: Test individual layers
If end-to-end test fails, isolate the problem by testing layers individually. For each layer: freeze other layers, test equivariance of that layer alone. This identifies which layer breaks equivariance. Use Layer-wise Testing protocol. Check nonlinearities, normalizations, and custom operations especially carefully.
Step 4: Check gradient equivariance
Verify that gradients also respect equivariance (important for training). Compute gradients at x and T(x). Check that gradients transform appropriately. Gradient bugs can cause training to "unlearn" equivariance. See Gradient Testing.
Step 5: Identify and diagnose failures
If tests fail, use Common Failure Modes to diagnose. Check: non-equivariant nonlinearities, batch normalization issues, incorrect output transformation, numerical precision problems, implementation bugs in custom layers. Provide specific fix recommendations. For step-by-step troubleshooting, consult Debugging Guide.
Step 6: Document audit results
Create audit report using Output Template. Include: pass/fail for each test, error magnitudes, identified issues, and recommendations. Distinguish between: exact equivariance (numerical precision), approximate equivariance (acceptable error), and broken equivariance (needs fixing). For detailed audit methodology, see Methodology Details. Quality criteria for this output are defined in Quality Rubric.

Test Implementation

End-to-End Equivariance Test

python
import numpy as np
import torch

def test_model_equivariance(model, x, input_transform, output_transform,
                            n_tests=100, tol=1e-5):
    """
    Test if model is equivariant: f(T(x)) ≈ T'(f(x))

    Args:
        model: The neural network to test
        x: Sample input tensor
        input_transform: Function that transforms input
        output_transform: Function that transforms output
        n_tests: Number of random transformations to test
        tol: Error tolerance

    Returns:
        dict with test results
    """
    model.eval()
    errors = []

    with torch.no_grad():
        for _ in range(n_tests):
            # Generate a random transformation; sample_random_transform()
            # is problem-specific and must be supplied by the caller
            T = sample_random_transform()

            # Method 1: Transform input, then apply model
            x_transformed = input_transform(x, T)
            y1 = model(x_transformed)

            # Method 2: Apply model, then transform output
            y = model(x)
            y2 = output_transform(y, T)

            # Compute error
            error = torch.norm(y1 - y2).item()
            relative_error = error / (torch.norm(y2).item() + 1e-8)
            errors.append({
                'absolute': error,
                'relative': relative_error
            })

    return {
        'mean_absolute': np.mean([e['absolute'] for e in errors]),
        'max_absolute': np.max([e['absolute'] for e in errors]),
        'mean_relative': np.mean([e['relative'] for e in errors]),
        'max_relative': np.max([e['relative'] for e in errors]),
        'pass': all(e['relative'] < tol for e in errors)
    }
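The test above calls `sample_random_transform()`, which is problem-specific and not defined here. A minimal sketch for planar rotations, assuming inputs are point clouds of shape `(N, 2)` transformed as `x @ T.T`:

```python
import torch

def sample_random_transform():
    """Sample a random planar rotation matrix (an element of SO(2))."""
    theta = torch.rand(()) * 2 * torch.pi
    c, s = torch.cos(theta), torch.sin(theta)
    # Assemble [[cos, -sin], [sin, cos]]
    return torch.stack([torch.stack([c, -s]), torch.stack([s, c])])
```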

Invariance Test (Simpler Case)

python
import numpy as np
import torch

def test_model_invariance(model, x, transform, n_tests=100, tol=1e-5):
    """Test if model output is invariant to transformations."""
    model.eval()
    errors = []

    with torch.no_grad():
        y_original = model(x)

        for _ in range(n_tests):
            T = sample_random_transform()
            x_transformed = transform(x, T)
            y_transformed = model(x_transformed)

            error = torch.norm(y_transformed - y_original).item()
            errors.append(error)

    return {
        'mean_error': np.mean(errors),
        'max_error': np.max(errors),
        'pass': max(errors) < tol
    }

Layer-wise Testing

Protocol

python
def test_layer_equivariance(layer, x, input_transform, output_transform,
                            tol=1e-5):
    """Test a single layer for equivariance."""
    layer.eval()

    with torch.no_grad():
        T = sample_random_transform()

        # Transform then layer
        y1 = layer(input_transform(x, T))

        # Layer then transform
        y2 = output_transform(layer(x), T)

        error = torch.norm(y1 - y2).item()

    return {
        'layer': layer.__class__.__name__,
        'error': error,
        'pass': error < tol
    }

def audit_all_layers(model, x, transforms):
    """Test each layer individually."""
    results = []

    for name, layer in model.named_modules():
        # is_testable_layer() is a user-supplied predicate selecting
        # leaf layers worth testing (e.g. skipping container modules)
        if is_testable_layer(layer):
            result = test_layer_equivariance(layer, x, *transforms)
            result['name'] = name
            results.append(result)

    return results

What to Test Per Layer

Layer Type       What to Check
Convolution      Kernel equivariance
Nonlinearity     Should preserve equivariance
Normalization    Often breaks equivariance
Pooling          Correct aggregation
Linear           Weight sharing patterns
Attention        Permutation equivariance

Gradient Testing

Why Test Gradients?

The forward pass can be equivariant while the backward pass is not. This causes:
  • Training instability
  • Model "unlearning" equivariance
  • Inconsistent optimization

Gradient Equivariance Test

python
def test_gradient_equivariance(model, x, loss_fn, transform, tol=1e-4):
    """Test if gradients respect equivariance."""
    model.train()

    # Gradients at original input
    x1 = x.clone().requires_grad_(True)
    y1 = model(x1)
    loss1 = loss_fn(y1)
    loss1.backward()
    grad1 = x1.grad.clone()

    # Gradients at transformed input
    model.zero_grad()
    T = sample_random_transform()
    x2 = transform(x.clone(), T).requires_grad_(True)
    y2 = model(x2)
    loss2 = loss_fn(y2)
    loss2.backward()
    grad2 = x2.grad.clone()

    # Transform grad1 and compare to grad2; transform_gradient() is
    # problem-specific (how gradients transform depends on the group action)
    grad1_transformed = transform_gradient(grad1, T)
    error = torch.norm(grad2 - grad1_transformed).item()

    return {'error': error, 'pass': error < tol}
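`transform_gradient` above is likewise user-supplied. For the special case of an orthogonal transform T acting on the input and an invariant loss, the chain rule gives grad at T(x) = T applied to grad at x, i.e. input gradients transform exactly like inputs. A sketch under those assumptions only:

```python
def transform_gradient(grad, T):
    """For an orthogonal T acting by right-multiplication on (N, d)
    inputs and an invariant loss, the chain rule implies the input
    gradient transforms like the input: grad -> grad @ T.T."""
    return grad @ T.T
```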

Error Interpretation

Error Thresholds

Error Level      Interpretation                  Action
< 1e-6           Perfect (float32 precision)     Pass
1e-6 to 1e-4     Excellent (acceptable)          Pass
1e-4 to 1e-2     Approximate equivariance        Investigate
> 1e-2           Broken equivariance             Fix required

Relative vs Absolute Error

  • Absolute error: Raw difference magnitude
  • Relative error: Normalized by output magnitude
Use relative error when output magnitudes vary. Use absolute when comparing to numerical precision.

Common Failure Modes

1. Non-Equivariant Nonlinearity

Symptom: Error increases after nonlinearity layers
Cause: Applying ReLU or sigmoid directly to equivariant features
Fix: Use gated or norm-based nonlinearities, or restrict pointwise nonlinearities to invariant features
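One such fix can be sketched as a norm-gated nonlinearity: each vector feature is scaled by a scalar function of its own norm. Rotations preserve norms, so the gate commutes with the rotation. The `-1.0` shift inside the sigmoid is an arbitrary illustrative choice:

```python
import torch

def norm_gated_relu(v):
    """Scale each vector feature by a sigmoid of its norm.
    The per-vector gate is rotation-invariant (norms are preserved),
    so the whole operation is rotation-equivariant.
    v: tensor of shape (..., d) holding vector features."""
    norm = v.norm(dim=-1, keepdim=True)
    return v * torch.sigmoid(norm - 1.0)
```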

2. Batch Normalization Breaking Equivariance

Symptom: Error varies with batch composition
Cause: BN computes different statistics for different orientations
Fix: Use LayerNorm, GroupNorm, or an equivariant batch norm

3. Incorrect Output Transformation

Symptom: Test fails even for the identity transform
Cause: output_transform doesn't match the model's output type
Fix: Verify the output transformation matches the layer's output representation
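A quick way to confirm this diagnosis: run the comparison with the identity element. Both sides are then the same forward pass, so any nonzero error implicates the transform plumbing rather than the model. A sketch:

```python
import torch

def identity_sanity_check(model, x, input_transform, output_transform, T_id):
    """With the identity transform T_id, f(T(x)) and T'(f(x)) reduce to
    the same forward pass; a nonzero error means the input/output
    transform functions themselves are wrong."""
    with torch.no_grad():
        y1 = model(input_transform(x, T_id))
        y2 = output_transform(model(x), T_id)
    return torch.norm(y1 - y2).item()
```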

4. Numerical Precision Issues

Symptom: Small but non-zero error everywhere
Cause: Floating-point accumulation, interpolation
Fix: Use float64 for testing and accept a small tolerance

5. Custom Layer Bug

Symptom: Error isolated to a specific layer
Cause: Implementation error in a custom equivariant layer
Fix: Review the layer implementation against its equivariance constraints

6. Padding/Boundary Effects

Symptom: Error higher near edges
Cause: Padding doesn't respect the symmetry
Fix: Use circular padding or handle boundaries explicitly
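For discrete translations on inputs treated as periodic signals, this fix can be demonstrated with PyTorch's `padding_mode='circular'`: a stride-1 convolution with circular padding commutes exactly with cyclic shifts, while the default zero padding does not:

```python
import torch
import torch.nn as nn

# A stride-1 conv with circular padding is exactly shift-equivariant
# when shifts are taken to be cyclic (torch.roll).
conv = nn.Conv2d(1, 1, kernel_size=3, padding=1, padding_mode='circular')

x = torch.randn(1, 1, 8, 8)
shift = lambda t: torch.roll(t, shifts=(2, 3), dims=(-2, -1))

with torch.no_grad():
    err = torch.norm(conv(shift(x)) - shift(conv(x))).item()
# err sits at float32 precision; with zero padding it would not.
```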

Output Template

MODEL EQUIVARIANCE AUDIT REPORT
===============================

Model: [Model name/description]
Intended Symmetry: [Group]
Symmetry Type: [Invariant/Equivariant]

END-TO-END TESTS:
-----------------
Test samples: [N]
Transformations tested: [M]

Invariance/Equivariance Error:
- Mean absolute: [value]
- Max absolute: [value]
- Mean relative: [value]
- Max relative: [value]
- RESULT: [PASS/FAIL]

LAYER-WISE ANALYSIS:
--------------------
[For each layer]
- Layer: [name]
- Error: [value]
- Result: [PASS/FAIL]

GRADIENT TEST:
--------------
- Gradient equivariance error: [value]
- RESULT: [PASS/FAIL]

IDENTIFIED ISSUES:
------------------
1. [Issue description]
   - Location: [layer/component]
   - Severity: [High/Medium/Low]
   - Recommended fix: [description]

OVERALL VERDICT: [PASS/FAIL/NEEDS_ATTENTION]

Recommendations:
- [List of actions needed]