
Segment Anything Model (SAM)

Comprehensive guide to using Meta AI's Segment Anything Model for zero-shot image segmentation.

When to use SAM

Use SAM when:
  • Need to segment any object in images without task-specific training
  • Building interactive annotation tools with point/box prompts
  • Generating training data for other vision models
  • Need zero-shot transfer to new image domains
  • Building object detection/segmentation pipelines
  • Processing medical, satellite, or domain-specific images
Key features:
  • Zero-shot segmentation: Works on any image domain without fine-tuning
  • Flexible prompts: Points, bounding boxes, or previous masks
  • Automatic segmentation: Generate all object masks automatically
  • High quality: Trained on 1.1 billion masks from 11 million images
  • Multiple model sizes: ViT-B (fastest), ViT-L, ViT-H (most accurate)
  • ONNX export: Deploy in browsers and edge devices
Use alternatives instead:
  • YOLO/Detectron2: For real-time object detection with classes
  • Mask2Former: For semantic/panoptic segmentation with categories
  • GroundingDINO + SAM: For text-prompted segmentation
  • SAM 2: For video segmentation tasks

Quick start

Installation

bash
# From GitHub
pip install git+https://github.com/facebookresearch/segment-anything.git

# Optional dependencies
pip install opencv-python pycocotools matplotlib

# Or use HuggingFace transformers
pip install transformers

Download checkpoints

bash
# ViT-H (largest, most accurate) - 2.4GB
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

# ViT-L (medium) - 1.2GB
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth

# ViT-B (smallest, fastest) - 375MB
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth

Basic usage with SamPredictor

python
import cv2
import numpy as np
from segment_anything import sam_model_registry, SamPredictor

# Load model
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
sam.to(device="cuda")

# Create predictor
predictor = SamPredictor(sam)

# Set image (computes embeddings once)
image = cv2.imread("image.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
predictor.set_image(image)

# Predict with point prompts
input_point = np.array([[500, 375]])  # (x, y) coordinates
input_label = np.array([1])           # 1 = foreground, 0 = background

masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True  # Returns 3 mask options
)

# Select best mask
best_mask = masks[np.argmax(scores)]
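To sanity-check a predicted mask, it helps to blend it over the image. A minimal NumPy-only sketch of that overlay step — the `image` and `mask` below are synthetic stand-ins; in practice they come from the predictor above:

```python
import numpy as np

# Synthetic stand-ins for the real image and predicted mask
image = np.full((4, 4, 3), 100, dtype=np.uint8)
mask = np.zeros((4, 4), dtype=bool)
mask[1:3, 1:3] = True

# Alpha-blend a green tint onto the masked pixels only
color = np.array([0, 255, 0], dtype=np.float64)
alpha = 0.5
overlay = image.copy()
overlay[mask] = (alpha * color + (1 - alpha) * image[mask]).astype(np.uint8)

print(overlay[1, 1])  # masked pixel is tinted green: [ 50 177  50]
print(overlay[0, 0])  # unmasked pixel is unchanged: [100 100 100]
```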

HuggingFace Transformers

python
import torch
from PIL import Image
from transformers import SamModel, SamProcessor

# Load model and processor
model = SamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
model.to("cuda")

# Process image with point prompt
image = Image.open("image.jpg")
input_points = [[[450, 600]]]  # Batch of points

inputs = processor(image, input_points=input_points, return_tensors="pt")
inputs = {k: v.to("cuda") for k, v in inputs.items()}

# Generate masks
with torch.no_grad():
    outputs = model(**inputs)

# Post-process masks to original size
masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu()
)

Core concepts

Model architecture

SAM Architecture:
┌─────────────────┐     ┌─────────────────┐     ┌─────────────────┐
│  Image Encoder  │────▶│ Prompt Encoder  │────▶│  Mask Decoder   │
│     (ViT)       │     │ (Points/Boxes)  │     │ (Transformer)   │
└─────────────────┘     └─────────────────┘     └─────────────────┘
        │                       │                       │
   Image Embeddings      Prompt Embeddings         Masks + IoU
   (computed once)       (per prompt)             predictions
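The practical consequence of this split: the expensive ViT encoder runs once per image, while the lightweight prompt encoder and mask decoder run once per prompt. A toy sketch of that call pattern — `ToySAM` is a stand-in class for illustration, not the real model:

```python
class ToySAM:
    """Stand-in illustrating SAM's encode-once, decode-per-prompt design."""

    def __init__(self):
        self.embedding = None
        self.encoder_calls = 0

    def set_image(self, pixels):
        # Expensive: the ViT image encoder runs here, once per image
        self.encoder_calls += 1
        self.embedding = sum(pixels) / len(pixels)  # stand-in embedding

    def predict(self, point):
        # Cheap: prompt encoder + mask decoder reuse the cached embedding
        assert self.embedding is not None, "call set_image first"
        return (point, self.embedding)

toy = ToySAM()
toy.set_image([1, 2, 3])
for pt in [(10, 10), (20, 20), (30, 30)]:
    toy.predict(pt)

print(toy.encoder_calls)  # 1 -- three prompts, one encoder pass
```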

Model variants

Model  Checkpoint  Size    Speed    Accuracy
ViT-H  vit_h       2.4 GB  Slowest  Best
ViT-L  vit_l       1.2 GB  Medium   Good
ViT-B  vit_b       375 MB  Fastest  Good
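Choosing between the variants usually comes down to available VRAM. A hypothetical helper (not part of `segment_anything`) that picks the largest variant whose rough runtime footprint fits a budget:

```python
# Checkpoint sizes in GB, from the table above
CHECKPOINT_GB = {"vit_h": 2.4, "vit_l": 1.2, "vit_b": 0.375}

def pick_variant(vram_gb, headroom=3.0):
    # Rough rule of thumb (an assumption, not a measured figure):
    # inference needs several times the checkpoint size in VRAM
    for name in ("vit_h", "vit_l", "vit_b"):
        if CHECKPOINT_GB[name] * headroom <= vram_gb:
            return name
    return "vit_b"  # smallest variant as the fallback

print(pick_variant(24.0))  # vit_h
print(pick_variant(2.0))   # vit_b
```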

Prompt types

Prompt              Description              Use Case
Point (foreground)  Click on object          Single object selection
Point (background)  Click outside object     Exclude regions
Bounding box        Rectangle around object  Larger objects
Previous mask       Low-res mask input       Iterative refinement

Interactive segmentation

Point prompts

python
# Single foreground point
input_point = np.array([[500, 375]])
input_label = np.array([1])

masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True
)

# Multiple points (foreground + background)
input_points = np.array([[500, 375], [600, 400], [450, 300]])
input_labels = np.array([1, 1, 0])  # 2 foreground, 1 background

masks, scores, logits = predictor.predict(
    point_coords=input_points,
    point_labels=input_labels,
    multimask_output=False  # Single mask when prompts are clear
)

Box prompts

python
# Bounding box [x1, y1, x2, y2]
input_box = np.array([425, 600, 700, 875])

masks, scores, logits = predictor.predict(
    box=input_box,
    multimask_output=False
)

Combined prompts

python
# Box + points for precise control
masks, scores, logits = predictor.predict(
    point_coords=np.array([[500, 375]]),
    point_labels=np.array([1]),
    box=np.array([400, 300, 700, 600]),
    multimask_output=False
)

Iterative refinement

python
# Initial prediction
masks, scores, logits = predictor.predict(
    point_coords=np.array([[500, 375]]),
    point_labels=np.array([1]),
    multimask_output=True
)

# Refine with additional point using previous mask
masks, scores, logits = predictor.predict(
    point_coords=np.array([[500, 375], [550, 400]]),
    point_labels=np.array([1, 0]),                     # Add background point
    mask_input=logits[np.argmax(scores)][None, :, :],  # Use best mask
    multimask_output=False
)

Automatic mask generation

Basic automatic segmentation

python
from segment_anything import SamAutomaticMaskGenerator

# Create generator
mask_generator = SamAutomaticMaskGenerator(sam)

# Generate all masks
masks = mask_generator.generate(image)

Each mask contains:
  • segmentation: binary mask
  • bbox: [x, y, w, h]
  • area: pixel count
  • predicted_iou: quality score
  • stability_score: robustness score
  • point_coords: generating point
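The stability_score is worth understanding: SAM thresholds the mask logits slightly above and slightly below the cutoff and measures how much the mask changes. A simplified sketch of that idea (the real implementation in `segment_anything` computes this batched over all masks):

```python
import numpy as np

def stability_score(logits, mask_threshold=0.0, offset=1.0):
    # IoU between the mask cut at a high vs. a low threshold; since the
    # high-threshold mask is a subset of the low-threshold one, the IoU
    # reduces to a simple ratio. Logits far from the threshold score ~1.
    hi = logits > (mask_threshold + offset)
    lo = logits > (mask_threshold - offset)
    return hi.sum() / max(lo.sum(), 1)

# Confident blob: logits far from zero -> perfectly stable
confident = np.full((8, 8), -10.0)
confident[2:6, 2:6] = 10.0
print(stability_score(confident))  # 1.0

# Fuzzy blob: logits hugging zero -> unstable
fuzzy = np.full((8, 8), -0.5)
fuzzy[2:6, 2:6] = 0.5
print(stability_score(fuzzy))  # 0.0
```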

Customized generation

python
mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=32,           # Grid density (more = more masks)
    pred_iou_thresh=0.88,         # Quality threshold
    stability_score_thresh=0.95,  # Stability threshold
    crop_n_layers=1,              # Multi-scale crops
    crop_n_points_downscale_factor=2,
    min_mask_region_area=100,     # Remove tiny masks
)

masks = mask_generator.generate(image)
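The points_per_side grid can be pictured directly: the generator prompts the model on a uniform lattice of points. A sketch of how such a grid is built, mirroring the behavior of the library's point-grid utility (normalized [0, 1] coordinates assumed):

```python
import numpy as np

def build_point_grid(n_per_side):
    # Evenly spaced points in [0, 1] x [0, 1], offset half a cell from
    # the edges, so prompts cover the image uniformly
    offset = 1 / (2 * n_per_side)
    coords = np.linspace(offset, 1 - offset, n_per_side)
    xs, ys = np.meshgrid(coords, coords)
    return np.stack([xs.ravel(), ys.ravel()], axis=-1)

grid = build_point_grid(32)
print(grid.shape)  # (1024, 2) -- 32 x 32 prompt points
print(grid[0])     # first point, near the top-left corner
```

Doubling `points_per_side` quadruples the number of prompts, which is why it dominates both mask count and runtime.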

Filtering masks

python
# Sort by area (largest first)
masks = sorted(masks, key=lambda x: x['area'], reverse=True)

# Filter by predicted IoU
high_quality = [m for m in masks if m['predicted_iou'] > 0.9]

# Filter by stability score
stable_masks = [m for m in masks if m['stability_score'] > 0.95]
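Automatic generation can also emit near-duplicate masks for the same object. A sketch of a greedy IoU-based de-duplication pass — the helper names are illustrative, not part of the library:

```python
import numpy as np

def mask_iou(a, b):
    # Intersection-over-union of two boolean masks
    union = np.logical_or(a, b).sum()
    return np.logical_and(a, b).sum() / union if union else 0.0

def dedupe_masks(masks, iou_thresh=0.9):
    # Keep the highest-quality mask of each near-duplicate group
    kept = []
    for m in sorted(masks, key=lambda x: x["predicted_iou"], reverse=True):
        if all(mask_iou(m["segmentation"], k["segmentation"]) < iou_thresh
               for k in kept):
            kept.append(m)
    return kept

# Two near-identical masks and one distinct mask
a = np.zeros((10, 10), dtype=bool); a[0:5, 0:5] = True
b = a.copy()
c = np.zeros((10, 10), dtype=bool); c[6:9, 6:9] = True
masks = [
    {"segmentation": a, "predicted_iou": 0.95},
    {"segmentation": b, "predicted_iou": 0.90},
    {"segmentation": c, "predicted_iou": 0.92},
]
print(len(dedupe_masks(masks)))  # 2 -- the duplicate of `a` is dropped
```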

Batched inference

Multiple images

python
# Process multiple images efficiently
images = [cv2.imread(f"image_{i}.jpg") for i in range(10)]

all_masks = []
for image in images:
    predictor.set_image(image)
    masks, _, _ = predictor.predict(
        point_coords=np.array([[500, 375]]),
        point_labels=np.array([1]),
        multimask_output=True
    )
    all_masks.append(masks)

Multiple prompts per image

python
# Process multiple prompts efficiently (one image encoding)
predictor.set_image(image)

# Batch of point prompts
points = [
    np.array([[100, 100]]),
    np.array([[200, 200]]),
    np.array([[300, 300]]),
]

all_masks = []
for point in points:
    masks, scores, _ = predictor.predict(
        point_coords=point,
        point_labels=np.array([1]),
        multimask_output=True
    )
    all_masks.append(masks[np.argmax(scores)])

ONNX deployment

Export model

bash
python scripts/export_onnx_model.py \
    --checkpoint sam_vit_h_4b8939.pth \
    --model-type vit_h \
    --output sam_onnx.onnx \
    --return-single-mask

Use ONNX model

python
import numpy as np
import onnxruntime

# Load ONNX model
ort_session = onnxruntime.InferenceSession("sam_onnx.onnx")

# Run inference (image embeddings computed separately by the image encoder)
masks = ort_session.run(
    None,
    {
        "image_embeddings": image_embeddings,
        "point_coords": point_coords,
        "point_labels": point_labels,
        "mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32),
        "has_mask_input": np.array([0], dtype=np.float32),
        "orig_im_size": np.array([h, w], dtype=np.float32),
    },
)

Common workflows

Workflow 1: Annotation tool

python
import cv2

# Load model and precompute image embeddings
predictor = SamPredictor(sam)
predictor.set_image(image)

def on_click(event, x, y, flags, param):
    if event == cv2.EVENT_LBUTTONDOWN:
        # Foreground point at the click location
        masks, scores, _ = predictor.predict(
            point_coords=np.array([[x, y]]),
            point_labels=np.array([1]),
            multimask_output=True
        )
        # Display best mask
        display_mask(masks[np.argmax(scores)])

Workflow 2: Object extraction

python
def extract_object(image, point):
    """Extract object at point with transparent background."""
    predictor.set_image(image)

    masks, scores, _ = predictor.predict(
        point_coords=np.array([point]),
        point_labels=np.array([1]),
        multimask_output=True
    )

    best_mask = masks[np.argmax(scores)]

    # Create RGBA output
    rgba = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
    rgba[:, :, :3] = image
    rgba[:, :, 3] = best_mask * 255

    return rgba
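The alpha-channel step in extract_object is plain NumPy and can be checked on synthetic data — the `image` and `best_mask` below are stand-ins for what the predictor returns:

```python
import numpy as np

# Synthetic stand-ins for the image and the predicted mask
image = np.full((4, 4, 3), 200, dtype=np.uint8)
best_mask = np.zeros((4, 4), dtype=bool)
best_mask[1:3, 1:3] = True

# Copy RGB, set alpha to 255 inside the mask and 0 outside
rgba = np.zeros((4, 4, 4), dtype=np.uint8)
rgba[:, :, :3] = image
rgba[:, :, 3] = best_mask * 255

print(rgba[1, 1, 3])  # 255 -- object pixel is opaque
print(rgba[0, 0, 3])  # 0   -- background pixel is transparent
```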

Workflow 3: Medical image segmentation

python
# Process medical images (grayscale to RGB)
medical_image = cv2.imread("scan.png", cv2.IMREAD_GRAYSCALE)
rgb_image = cv2.cvtColor(medical_image, cv2.COLOR_GRAY2RGB)

predictor.set_image(rgb_image)

# Segment region of interest
masks, scores, _ = predictor.predict(
    box=np.array([x1, y1, x2, y2]),  # ROI bounding box
    multimask_output=True
)

Output format

Mask data structure

python
# SamAutomaticMaskGenerator output
{
    "segmentation": np.ndarray,   # H×W binary mask
    "bbox": [x, y, w, h],         # Bounding box
    "area": int,                  # Pixel count
    "predicted_iou": float,       # 0-1 quality score
    "stability_score": float,     # 0-1 robustness score
    "crop_box": [x, y, w, h],     # Generation crop region
    "point_coords": [[x, y]],     # Input point
}

COCO RLE format

python
from pycocotools import mask as mask_utils

# Encode mask to RLE
rle = mask_utils.encode(np.asfortranarray(mask.astype(np.uint8)))
rle["counts"] = rle["counts"].decode("utf-8")

# Decode RLE to mask
decoded_mask = mask_utils.decode(rle)
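The RLE idea itself is simple enough to sketch without pycocotools: the mask is flattened in column-major (Fortran) order and stored as alternating run lengths, starting with the run of zeros. A simplified, uncompressed version (COCO's compact string encoding adds a byte-packing step on top of these counts):

```python
import numpy as np

def rle_encode(mask):
    # Column-major run lengths, starting with the run of zeros
    flat = mask.ravel(order="F").astype(np.uint8)
    counts, prev, run = [], 0, 0
    for v in flat:
        if v == prev:
            run += 1
        else:
            counts.append(run)
            prev, run = v, 1
    counts.append(run)
    return {"size": list(mask.shape), "counts": counts}

def rle_decode(rle):
    h, w = rle["size"]
    flat = np.zeros(h * w, dtype=np.uint8)
    pos, val = 0, 0
    for c in rle["counts"]:
        flat[pos:pos + c] = val
        pos += c
        val = 1 - val
    return flat.reshape((h, w), order="F")

mask = np.zeros((3, 4), dtype=np.uint8)
mask[1, 1:3] = 1
rle = rle_encode(mask)
print(rle["counts"])  # [4, 1, 2, 1, 4]
assert np.array_equal(rle_decode(rle), mask)
```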

Performance optimization

GPU memory

python
# Use smaller model for limited VRAM
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")

# Process images in batches, clearing the CUDA cache between large batches
torch.cuda.empty_cache()

Speed optimization

python
# Use half precision
sam = sam.half()

# Reduce points for automatic generation
mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=16,  # Default is 32
)

# Use ONNX for deployment: export with --return-single-mask
# for faster inference

Common issues

Issue                 Solution
Out of memory         Use ViT-B model, reduce image size
Slow inference        Use ViT-B, reduce points_per_side
Poor mask quality     Try different prompts, use box + points
Edge artifacts        Use stability_score filtering
Small objects missed  Increase points_per_side

References

  • Advanced Usage - Batching, fine-tuning, integration
  • Troubleshooting - Common issues and solutions

Resources
