model-pruning


Model Pruning: Compressing LLMs


When to Use This Skill


Use Model Pruning when you need to:
  • Reduce model size by 40-60% with <1% accuracy loss
  • Accelerate inference using hardware-friendly sparsity (2-4× speedup)
  • Deploy on constrained hardware (mobile, edge devices)
  • Compress without retraining using one-shot methods
  • Enable efficient serving with reduced memory footprint
Key Techniques: Wanda (weights × activations), SparseGPT (second-order), structured pruning, N:M sparsity
Papers: Wanda ICLR 2024 (arXiv 2306.11695), SparseGPT (arXiv 2301.00774)

Installation



Wanda implementation


```bash
git clone https://github.com/locuslab/wanda
cd wanda
pip install -r requirements.txt
```

Optional: SparseGPT


```bash
git clone https://github.com/IST-DASLab/sparsegpt
cd sparsegpt
pip install -e .
```

Dependencies


```bash
pip install torch transformers accelerate
```

Quick Start


Wanda Pruning (One-Shot, No Retraining)


Source: ICLR 2024 (arXiv 2306.11695)

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
```

Load model


```python
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,
    device_map="cuda",
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
```

Calibration data (small dataset for activation statistics)


```python
calib_data = [
    "The quick brown fox jumps over the lazy dog.",
    "Machine learning is transforming the world.",
    "Artificial intelligence powers modern applications.",
]
```

Wanda pruning function


```python
def wanda_prune(model, calib_data, sparsity=0.5):
    """Wanda: prune by |weight| × input activation norm.

    Args:
        sparsity: Fraction of weights to prune (0.5 = 50%).
    """
    # 1. Collect activation statistics
    activations = {}

    def hook_fn(name):
        def hook(module, input, output):
            # Mean |activation| per input channel
            # (flatten batch and sequence dims so the result is 1-D)
            x = input[0].detach().abs()
            activations[name] = x.reshape(-1, x.shape[-1]).mean(dim=0)
        return hook

    # Register hooks for all linear layers
    hooks = []
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            hooks.append(module.register_forward_hook(hook_fn(name)))

    # Run calibration data (uses the tokenizer loaded above; each text
    # overwrites the stored statistics -- a production implementation
    # would accumulate a running average over all calibration samples)
    model.eval()
    with torch.no_grad():
        for text in calib_data:
            inputs = tokenizer(text, return_tensors="pt").to(model.device)
            model(**inputs)

    # Remove hooks
    for hook in hooks:
        hook.remove()

    # 2. Prune weights based on |weight| × activation
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear) and name in activations:
            W = module.weight.data          # (out_features, in_features)
            act = activations[name]         # (in_features,)

            # Compute importance: |weight| × per-channel activation norm
            importance = W.abs() * act.unsqueeze(0)

            # Flatten and find the threshold at the requested sparsity
            threshold = torch.quantile(importance.flatten(), sparsity)

            # Create mask and apply it (prune)
            mask = importance >= threshold
            W *= mask.float()

    return model
```

Apply Wanda pruning (50% sparsity, one-shot, no retraining)


```python
pruned_model = wanda_prune(model, calib_data, sparsity=0.5)
```

Save


```python
pruned_model.save_pretrained("./llama-2-7b-wanda-50")
```
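Before evaluating, it is worth sanity-checking the realized sparsity. The helper below is an illustrative addition (not part of the Wanda repo): it counts exactly-zero weights across all `Linear` layers.

```python
import torch

def measure_sparsity(model):
    """Fraction of exactly-zero weights across all Linear layers."""
    zeros, total = 0, 0
    for module in model.modules():
        if isinstance(module, torch.nn.Linear):
            w = module.weight.data
            zeros += (w == 0).sum().item()
            total += w.numel()
    return zeros / total

# Toy check: magnitude-prune half the weights of one layer, expect 0.5
layer = torch.nn.Linear(8, 8, bias=False)
flat = layer.weight.data.view(-1)
_, idx = torch.topk(flat.abs(), flat.numel() // 2, largest=False)
flat[idx] = 0.0
print(round(measure_sparsity(torch.nn.Sequential(layer)), 2))  # → 0.5
```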

SparseGPT (Second-Order Pruning)


Source: arXiv 2301.00774

```python
from sparsegpt import SparseGPT
```

Load model


```python
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
```

Initialize SparseGPT


```python
pruner = SparseGPT(model)
```

Calibration data


```python
calib_data = load_calibration_data()  # ~128 samples
```

Prune (one-shot, layer-wise reconstruction)


```python
pruned_model = pruner.prune(
    calib_data=calib_data,
    sparsity=0.5,     # 50% sparsity
    prunen=0,         # Unstructured (0) or N:M structured
    prunem=0,
    percdamp=0.01,    # Damping for Hessian inverse
)
```

Results: Near-lossless pruning at 50% sparsity



N:M Structured Pruning (Hardware Accelerator)


```python
import torch
import torch.nn.functional as F

def nm_prune(weight, n=2, m=4):
    """
    N:M pruning: Keep N weights per M consecutive weights.
    Example: 2:4 = keep 2 out of every 4 weights.

    Compatible with NVIDIA sparse tensor cores (2:4, 4:8).
    """
    # Reshape weight into groups of M
    shape = weight.shape
    weight_flat = weight.flatten()

    # Pad to a multiple of M
    pad_size = (m - weight_flat.numel() % m) % m
    weight_padded = F.pad(weight_flat, (0, pad_size))

    # Reshape into (num_groups, m)
    weight_grouped = weight_padded.reshape(-1, m)

    # Find the top-N weights by magnitude in each group
    _, indices = torch.topk(weight_grouped.abs(), n, dim=-1)

    # Create mask
    mask = torch.zeros_like(weight_grouped)
    mask.scatter_(1, indices, 1.0)

    # Apply mask
    weight_pruned = weight_grouped * mask

    # Reshape back
    weight_pruned = weight_pruned.flatten()[:weight_flat.numel()]
    return weight_pruned.reshape(shape)
```

Apply 2:4 sparsity (NVIDIA hardware)


```python
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear):
        module.weight.data = nm_prune(module.weight.data, n=2, m=4)
```

50% sparsity, 2× speedup on A100 with sparse tensor cores


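A cheap invariant check (illustrative helper, not part of any library) confirms the pattern: after 2:4 pruning, every consecutive group of 4 weights must contain at most 2 nonzeros, or sparse tensor cores cannot use the result.

```python
import torch
import torch.nn.functional as F

def check_nm(weight, n=2, m=4):
    """True iff every consecutive group of m weights has at most n nonzeros."""
    flat = weight.flatten()
    pad = (m - flat.numel() % m) % m
    groups = F.pad(flat, (0, pad)).reshape(-1, m)
    return bool(((groups != 0).sum(dim=-1) <= n).all())

w = torch.tensor([[0.0, 1.2, 0.0, -0.7],
                  [0.3, 0.0, -0.9, 0.0]])
print(check_nm(w))                 # → True  (valid 2:4 pattern)
print(check_nm(torch.ones(2, 4)))  # → False (4 nonzeros per group)
```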

Core Concepts


1. Pruning Criteria


**Magnitude Pruning** (baseline):

```python
# Prune weights with the smallest absolute values
importance = weight.abs()
threshold = torch.quantile(importance, sparsity)
mask = importance >= threshold
```

**Wanda** (weights × activations):

```python
# Importance = |weight| × input_activation
# Better than magnitude alone (considers how heavily each input is used)
importance = weight.abs() * activation
```

**SparseGPT** (second-order):

```python
# Uses the Hessian (second derivative) for importance
# More accurate, but computationally more expensive
importance = weight ** 2 / hessian_diag  # diagonal of the layer Hessian
```
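To make the contrast concrete, here is a toy sketch (all values invented for illustration) where magnitude pruning and Wanda disagree about which weight matters least:

```python
import torch

# One toy linear layer: 2 outputs × 2 inputs
weight = torch.tensor([[0.10, 0.90],
                       [0.50, 0.40]])
# Mean |input activation| per input channel: channel 0 fires much harder
activation = torch.tensor([8.0, 1.0])

# Magnitude criterion: |W|
mag_importance = weight.abs()

# Wanda criterion: |W| × activation norm of the matching input channel
wanda_importance = weight.abs() * activation.unsqueeze(0)

# Flattened index of the least-important weight under each criterion
mag_victim = torch.argmin(mag_importance)
wanda_victim = torch.argmin(wanda_importance)

# Magnitude drops W[0,0] = 0.10; Wanda instead drops W[1,1] = 0.40,
# because W[0,0] sits on a heavily-used channel (0.10 × 8.0 = 0.8)
print(mag_victim.item(), wanda_victim.item())  # → 0 3
```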

2. Structured vs Unstructured


Unstructured (fine-grained):
  • Prune individual weights
  • Higher quality (better accuracy)
  • No hardware speedup (irregular sparsity)
Structured (coarse-grained):
  • Prune entire neurons, heads, or layers
  • Lower quality (more accuracy loss)
  • Hardware speedup (regular sparsity)
Semi-structured (N:M):
  • Best of both worlds
  • 50% sparsity (2:4) → 2× speedup on NVIDIA GPUs
  • Minimal accuracy loss
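The trade-off can be sketched on a single `Linear` layer (an illustrative sketch, not from the papers): unstructured pruning zeroes individual entries but keeps the dense shape, while structured pruning removes whole output neurons (rows), which is what actually shrinks the computation.

```python
import torch

torch.manual_seed(0)
W = torch.nn.Linear(8, 4, bias=False).weight.data  # (out=4, in=8)

# Unstructured: zero the 50% smallest-magnitude individual weights
thresh = torch.quantile(W.abs().flatten(), 0.5)
W_unstructured = W * (W.abs() >= thresh).float()   # still (4, 8), half zeros

# Structured: keep only the 2 output neurons (rows) with largest L2 norm
keep = torch.topk(W.norm(dim=1), k=2).indices
W_structured = W[keep]                             # genuinely smaller: (2, 8)

print(tuple(W_unstructured.shape), tuple(W_structured.shape))  # → (4, 8) (2, 8)
```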

3. Sparsity Patterns


```python
# Unstructured (random)
# [1, 0, 1, 0, 1, 1, 0, 0]
# Pros: Flexible, high quality
# Cons: No speedup

# Structured (block)
# [1, 1, 0, 0, 1, 1, 0, 0]
# Pros: Hardware friendly
# Cons: More accuracy loss

# N:M (semi-structured)
# [1, 0, 1, 0] [1, 1, 0, 0]  (2:4 pattern)
# Pros: Hardware speedup + good quality
# Cons: Requires specific hardware (NVIDIA)
```

Pruning Strategies


Strategy 1: Gradual Magnitude Pruning


```python
def gradual_prune(model, initial_sparsity=0.0, final_sparsity=0.5, num_steps=100):
    """Gradually increase sparsity during training."""
    for step in range(num_steps):
        # Current sparsity
        current_sparsity = initial_sparsity + (final_sparsity - initial_sparsity) * (step / num_steps)

        # Prune at current sparsity
        for module in model.modules():
            if isinstance(module, torch.nn.Linear):
                weight = module.weight.data
                threshold = torch.quantile(weight.abs().flatten(), current_sparsity)
                mask = weight.abs() >= threshold
                weight *= mask.float()

        # Train one step (train_step: one optimization step, placeholder)
        train_step(model)

    return model
```

Strategy 2: Layer-wise Pruning


```python
def layer_wise_prune(model, sparsity_schedule=None):
    """Apply a different sparsity to each layer."""
    # Early layers: less pruning (more important)
    # Late layers: more pruning (less critical)
    if sparsity_schedule is None:
        sparsity_schedule = {
            "layer.0": 0.3,   # 30% sparsity
            "layer.1": 0.4,
            "layer.2": 0.5,
            "layer.3": 0.6,   # 60% sparsity
        }

    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            # Match this module to its layer's target sparsity
            for layer_name, sparsity in sparsity_schedule.items():
                if layer_name in name:
                    # prune_layer: magnitude-prune one module (placeholder)
                    prune_layer(module, sparsity)
                    break

    return model
```

Strategy 3: Iterative Pruning + Fine-tuning


```python
def iterative_prune_finetune(model, target_sparsity=0.5, iterations=5):
    """Prune gradually with fine-tuning between iterations."""
    current_sparsity = 0.0
    sparsity_increment = target_sparsity / iterations

    for i in range(iterations):
        # Increase sparsity
        current_sparsity += sparsity_increment

        # Prune
        prune_model(model, sparsity=current_sparsity)

        # Fine-tune (recover accuracy)
        fine_tune(model, epochs=2, lr=1e-5)

    return model
```

Results: Better accuracy than one-shot at high sparsity



Production Deployment


Complete Pruning Pipeline


```python
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

def production_pruning_pipeline(
    model_name="meta-llama/Llama-2-7b-hf",
    target_sparsity=0.5,
    method="wanda",  # or "sparsegpt"
):
    # 1. Load model
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # 2. Load calibration data
    calib_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1000]")

    # 3. Apply pruning
    if method == "wanda":
        pruned_model = wanda_prune(model, calib_dataset, sparsity=target_sparsity)
    elif method == "sparsegpt":
        pruner = SparseGPT(model)
        pruned_model = pruner.prune(calib_dataset, sparsity=target_sparsity)

    # 4. (Optional) Fine-tune to recover accuracy
    # (finetune_dataset must be provided by the surrounding code)
    training_args = TrainingArguments(
        output_dir="./pruned-model",
        num_train_epochs=1,
        per_device_train_batch_size=4,
        learning_rate=1e-5,
        bf16=True,
    )

    trainer = Trainer(
        model=pruned_model,
        args=training_args,
        train_dataset=finetune_dataset,
    )

    trainer.train()

    # 5. Save
    pruned_model.save_pretrained("./pruned-llama-7b-50")
    tokenizer.save_pretrained("./pruned-llama-7b-50")

    return pruned_model
```

Usage


```python
pruned_model = production_pruning_pipeline(
    model_name="meta-llama/Llama-2-7b-hf",
    target_sparsity=0.5,
    method="wanda",
)
```

Evaluation


```python
from lm_eval import evaluator
```

Evaluate pruned vs original model


```python
original_results = evaluator.simple_evaluate(
    model="hf",
    model_args="pretrained=meta-llama/Llama-2-7b-hf",
    tasks=["arc_easy", "hellaswag", "winogrande"],
)
pruned_results = evaluator.simple_evaluate(
    model="hf",
    model_args="pretrained=./pruned-llama-7b-50",
    tasks=["arc_easy", "hellaswag", "winogrande"],
)
```

Compare


```python
orig_acc = original_results["results"]["arc_easy"]["acc"]
pruned_acc = pruned_results["results"]["arc_easy"]["acc"]
print(f"Original: {orig_acc:.3f}")
print(f"Pruned: {pruned_acc:.3f}")
print(f"Degradation: {orig_acc - pruned_acc:.3f}")
```

Typical results at 50% sparsity (LLaMA-7B):
- Wanda: <1% accuracy loss
- SparseGPT: <0.5% accuracy loss
- Magnitude: 2-3% accuracy loss

Best Practices


1. Sparsity Selection


```python
# Conservative (safe)
sparsity = 0.3  # 30%, <0.5% loss

# Balanced (recommended)
sparsity = 0.5  # 50%, ~1% loss

# Aggressive (risky)
sparsity = 0.7  # 70%, 2-5% loss

# Extreme (model-dependent)
sparsity = 0.9  # 90%, significant degradation
```

2. Method Selection


```python
# One-shot, no retraining → Wanda or SparseGPT
if no_retraining_budget:
    use_method = "wanda"      # Faster

# Best quality → SparseGPT
if need_best_quality:
    use_method = "sparsegpt"  # More accurate

# Hardware speedup → N:M structured
if need_speedup:
    use_method = "nm_prune"   # 2:4 or 4:8
```

3. Avoid Common Pitfalls


```python
# ❌ Bad: Pruning without calibration data
prune_random(model)  # No activation statistics

# ✅ Good: Use calibration data
prune_wanda(model, calib_data)

# ❌ Bad: Too high sparsity in one shot
prune(model, sparsity=0.9)  # Massive accuracy loss

# ✅ Good: Gradual or iterative
iterative_prune(model, target=0.9, steps=10)
```

Performance Comparison


Pruning methods at 50% sparsity (LLaMA-7B):

| Method | Accuracy Loss | Speed | Memory | Retraining Needed |
|---|---|---|---|---|
| Magnitude | -2.5% | 1.0× | -50% | No |
| Wanda | -0.8% | 1.0× | -50% | No |
| SparseGPT | -0.4% | 1.0× | -50% | No |
| N:M (2:4) | -1.0% | 2.0× | -50% | No |
| Structured | -3.0% | 2.0× | -50% | No |

Source: Wanda paper (ICLR 2024), SparseGPT paper

Resources
