optimizing-attention-flash

Compare original and translation side by side

🇺🇸

Original

English
🇨🇳

Translation

Chinese

Flash Attention - Fast Memory-Efficient Attention

Flash Attention - 快速且内存高效的注意力机制

Quick start

快速开始

Flash Attention provides 2-4x speedup and 10-20x memory reduction for transformer attention through IO-aware tiling and recomputation.
PyTorch native (easiest, PyTorch 2.2+):
python
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)  # [batch, heads, seq, dim]
k = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
Flash Attention通过IO感知分块和重计算技术,为Transformer注意力机制带来2-4倍的速度提升和10-20倍的内存占用降低。
PyTorch原生方式(最简单,需PyTorch 2.2+):
python
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)  # [batch, heads, seq, dim]
k = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)

Automatically uses Flash Attention if available

Automatically uses Flash Attention if available

out = F.scaled_dot_product_attention(q, k, v)

**flash-attn library (more features)**:
```bash
pip install flash-attn --no-build-isolation
python
from flash_attn import flash_attn_func
out = F.scaled_dot_product_attention(q, k, v)

**flash-attn库(功能更丰富)**:
```bash
pip install flash-attn --no-build-isolation
python
from flash_attn import flash_attn_func

q, k, v: [batch, seqlen, nheads, headdim]

q, k, v: [batch, seqlen, nheads, headdim]

out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
undefined
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
undefined

Common workflows

常见工作流

Workflow 1: Enable in existing PyTorch model

工作流1:在现有PyTorch模型中启用Flash Attention

Copy this checklist:
Flash Attention Integration:
- [ ] Step 1: Check PyTorch version (≥2.2)
- [ ] Step 2: Enable Flash Attention backend
- [ ] Step 3: Verify speedup with profiling
- [ ] Step 4: Test accuracy matches baseline
Step 1: Check PyTorch version
bash
python -c "import torch; print(torch.__version__)"
复制以下检查清单:
Flash Attention集成:
- [ ] 步骤1:检查PyTorch版本(≥2.2)
- [ ] 步骤2:启用Flash Attention后端
- [ ] 步骤3:通过性能分析验证速度提升
- [ ] 步骤4:测试精度是否与基准一致
步骤1:检查PyTorch版本
bash
python -c "import torch; print(torch.__version__)"

Should be ≥2.2.0

Should be ≥2.2.0


If <2.2, upgrade:
```bash
pip install --upgrade torch
Step 2: Enable Flash Attention backend
Replace standard attention:
python
undefined

如果版本<2.2,进行升级:
```bash
pip install --upgrade torch
步骤2:启用Flash Attention后端
替换标准注意力实现:
python
undefined

Before (standard attention)

Before (standard attention)

attn_weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d_k), dim=-1) out = attn_weights @ v
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d_k), dim=-1) out = attn_weights @ v

After (Flash Attention)

After (Flash Attention)

import torch.nn.functional as F out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

Force Flash Attention backend:
```python
with torch.backends.cuda.sdp_kernel(
    enable_flash=True,
    enable_math=False,
    enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(q, k, v)
Step 3: Verify speedup with profiling
python
import torch.utils.benchmark as benchmark

def test_attention(use_flash):
    q, k, v = [torch.randn(2, 8, 2048, 64, device='cuda', dtype=torch.float16) for _ in range(3)]

    if use_flash:
        with torch.backends.cuda.sdp_kernel(enable_flash=True):
            return F.scaled_dot_product_attention(q, k, v)
    else:
        attn = (q @ k.transpose(-2, -1) / 8.0).softmax(dim=-1)
        return attn @ v
import torch.nn.functional as F out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

强制使用Flash Attention后端:
```python
with torch.backends.cuda.sdp_kernel(
    enable_flash=True,
    enable_math=False,
    enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(q, k, v)
步骤3:通过性能分析验证速度提升
python
import torch.utils.benchmark as benchmark

def test_attention(use_flash):
    q, k, v = [torch.randn(2, 8, 2048, 64, device='cuda', dtype=torch.float16) for _ in range(3)]

    if use_flash:
        with torch.backends.cuda.sdp_kernel(enable_flash=True):
            return F.scaled_dot_product_attention(q, k, v)
    else:
        attn = (q @ k.transpose(-2, -1) / 8.0).softmax(dim=-1)
        return attn @ v

Benchmark

Benchmark

t_flash = benchmark.Timer(stmt='test_attention(True)', globals=globals()) t_standard = benchmark.Timer(stmt='test_attention(False)', globals=globals())
print(f"Flash: {t_flash.timeit(100).mean:.3f}s") print(f"Standard: {t_standard.timeit(100).mean:.3f}s")

Expected: 2-4x speedup for sequences >512 tokens.

**Step 4: Test accuracy matches baseline**

```python
t_flash = benchmark.Timer(stmt='test_attention(True)', globals=globals()) t_standard = benchmark.Timer(stmt='test_attention(False)', globals=globals())
print(f"Flash: {t_flash.timeit(100).mean:.3f}s") print(f"Standard: {t_standard.timeit(100).mean:.3f}s")

预期结果:对于序列长度>512的token,可实现2-4倍的速度提升。

**步骤4:测试精度是否与基准一致**

```python

Compare outputs

Compare outputs

q, k, v = [torch.randn(1, 8, 512, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
q, k, v = [torch.randn(1, 8, 512, 64, device='cuda', dtype=torch.float16) for _ in range(3)]

Flash Attention

Flash Attention

out_flash = F.scaled_dot_product_attention(q, k, v)
out_flash = F.scaled_dot_product_attention(q, k, v)

Standard attention

Standard attention

attn_weights = torch.softmax(q @ k.transpose(-2, -1) / 8.0, dim=-1) out_standard = attn_weights @ v
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / 8.0, dim=-1) out_standard = attn_weights @ v

Check difference

Check difference

diff = (out_flash - out_standard).abs().max() print(f"Max difference: {diff:.6f}")
diff = (out_flash - out_standard).abs().max() print(f"Max difference: {diff:.6f}")

Should be <1e-3 for float16

Should be <1e-3 for float16

undefined
undefined

Workflow 2: Use flash-attn library for advanced features

工作流2:使用flash-attn库实现高级功能

For multi-query attention, sliding window, or H100 FP8.
Copy this checklist:
flash-attn Library Setup:
- [ ] Step 1: Install flash-attn library
- [ ] Step 2: Modify attention code
- [ ] Step 3: Enable advanced features
- [ ] Step 4: Benchmark performance
Step 1: Install flash-attn library
bash
undefined
适用于多查询注意力、滑动窗口注意力或H100 FP8优化场景。
复制以下检查清单:
flash-attn库设置:
- [ ] 步骤1:安装flash-attn库
- [ ] 步骤2:修改注意力代码
- [ ] 步骤3:启用高级功能
- [ ] 步骤4:进行性能基准测试
步骤1:安装flash-attn库
bash
undefined

NVIDIA GPUs (CUDA 12.0+)

NVIDIA GPUs (CUDA 12.0+)

pip install flash-attn --no-build-isolation
pip install flash-attn --no-build-isolation

Verify installation

Verify installation

python -c "from flash_attn import flash_attn_func; print('Success')"

**Step 2: Modify attention code**

```python
from flash_attn import flash_attn_func
python -c "from flash_attn import flash_attn_func; print('Success')"

**步骤2:修改注意力代码**

```python
from flash_attn import flash_attn_func

Input: [batch_size, seq_len, num_heads, head_dim]

Input: [batch_size, seq_len, num_heads, head_dim]

Transpose from [batch, heads, seq, dim] if needed

Transpose from [batch, heads, seq, dim] if needed

q = q.transpose(1, 2) # [batch, seq, heads, dim] k = k.transpose(1, 2) v = v.transpose(1, 2)
out = flash_attn_func( q, k, v, dropout_p=0.1, causal=True, # For autoregressive models window_size=(-1, -1), # No sliding window softmax_scale=None # Auto-scale )
out = out.transpose(1, 2) # Back to [batch, heads, seq, dim]

**Step 3: Enable advanced features**

Multi-query attention (shared K/V across heads):
```python
from flash_attn import flash_attn_func
q = q.transpose(1, 2) # [batch, seq, heads, dim] k = k.transpose(1, 2) v = v.transpose(1, 2)
out = flash_attn_func( q, k, v, dropout_p=0.1, causal=True, # For autoregressive models window_size=(-1, -1), # No sliding window softmax_scale=None # Auto-scale )
out = out.transpose(1, 2) # Back to [batch, heads, seq, dim]

**步骤3:启用高级功能**

多查询注意力(所有注意力头共享K/V):
```python
from flash_attn import flash_attn_func

q: [batch, seq, num_q_heads, dim]

q: [batch, seq, num_q_heads, dim]

k, v: [batch, seq, num_kv_heads, dim] # Fewer KV heads

k, v: [batch, seq, num_kv_heads, dim] # Fewer KV heads

out = flash_attn_func(q, k, v) # Automatically handles MQA

Sliding window attention (local attention):
```python
out = flash_attn_func(q, k, v) # Automatically handles MQA

滑动窗口注意力(局部注意力):
```python

Only attend to window of 256 tokens before/after

Only attend to window of 256 tokens before/after

out = flash_attn_func( q, k, v, window_size=(256, 256), # (left, right) window causal=True )

**Step 4: Benchmark performance**

```python
import torch
from flash_attn import flash_attn_func
import time

q, k, v = [torch.randn(4, 4096, 32, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
out = flash_attn_func( q, k, v, window_size=(256, 256), # (left, right) window causal=True )

**步骤4:进行性能基准测试**

```python
import torch
from flash_attn import flash_attn_func
import time

q, k, v = [torch.randn(4, 4096, 32, 64, device='cuda', dtype=torch.float16) for _ in range(3)]

Warmup

Warmup

for _ in range(10): _ = flash_attn_func(q, k, v)
for _ in range(10): _ = flash_attn_func(q, k, v)

Benchmark

Benchmark

torch.cuda.synchronize() start = time.time() for _ in range(100): out = flash_attn_func(q, k, v) torch.cuda.synchronize() end = time.time()
print(f"Time per iteration: {(end-start)/100*1000:.2f}ms") print(f"Memory allocated: {torch.cuda.max_memory_allocated()/1e9:.2f}GB")
undefined
torch.cuda.synchronize() start = time.time() for _ in range(100): out = flash_attn_func(q, k, v) torch.cuda.synchronize() end = time.time()
print(f"Time per iteration: {(end-start)/100*1000:.2f}ms") print(f"Memory allocated: {torch.cuda.max_memory_allocated()/1e9:.2f}GB")
undefined

Workflow 3: H100 FP8 optimization (FlashAttention-3)

工作流3:H100 FP8优化(FlashAttention-3)

For maximum performance on H100 GPUs.
FP8 Setup:
- [ ] Step 1: Verify H100 GPU available
- [ ] Step 2: Install flash-attn with FP8 support
- [ ] Step 3: Convert inputs to FP8
- [ ] Step 4: Run with FP8 attention
Step 1: Verify H100 GPU
bash
nvidia-smi --query-gpu=name --format=csv
在H100 GPU上实现极致性能。
FP8设置:
- [ ] 步骤1:确认H100 GPU可用
- [ ] 步骤2:安装支持FP8的flash-attn库
- [ ] 步骤3:将输入转换为FP8格式
- [ ] 步骤4:运行FP8注意力计算
步骤1:确认H100 GPU可用
bash
nvidia-smi --query-gpu=name --format=csv

Should show "H100" or "H800"

Should show "H100" or "H800"


**Step 2: Install flash-attn with FP8 support**

```bash
pip install flash-attn --no-build-isolation

**步骤2:安装支持FP8的flash-attn库**

```bash
pip install flash-attn --no-build-isolation

FP8 support included for H100

FP8 support included for H100


**Step 3: Convert inputs to FP8**

```python
import torch

q = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
k = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)

**步骤3:将输入转换为FP8格式**

```python
import torch

q = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
k = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)

Convert to float8_e4m3 (FP8)

Convert to float8_e4m3 (FP8)

q_fp8 = q.to(torch.float8_e4m3fn) k_fp8 = k.to(torch.float8_e4m3fn) v_fp8 = v.to(torch.float8_e4m3fn)

**Step 4: Run with FP8 attention**

```python
from flash_attn import flash_attn_func
q_fp8 = q.to(torch.float8_e4m3fn) k_fp8 = k.to(torch.float8_e4m3fn) v_fp8 = v.to(torch.float8_e4m3fn)

**步骤4:运行FP8注意力计算**

```python
from flash_attn import flash_attn_func

FlashAttention-3 automatically uses FP8 kernels on H100

FlashAttention-3 automatically uses FP8 kernels on H100

out = flash_attn_func(q_fp8, k_fp8, v_fp8)
out = flash_attn_func(q_fp8, k_fp8, v_fp8)

Result: ~1.2 PFLOPS, 1.5-2x faster than FP16

Result: ~1.2 PFLOPS, 1.5-2x faster than FP16

undefined
undefined

When to use vs alternatives

使用场景与替代方案对比

Use Flash Attention when:
  • Training transformers with sequences >512 tokens
  • Running inference with long context (>2K tokens)
  • GPU memory constrained (OOM with standard attention)
  • Need 2-4x speedup without accuracy loss
  • Using PyTorch 2.2+ or can install flash-attn
Use alternatives instead:
  • Standard attention: Sequences <256 tokens (overhead not worth it)
  • xFormers: Need more attention variants (not just speed)
  • Memory-efficient attention: CPU inference (Flash Attention needs GPU)
推荐使用Flash Attention的场景:
  • 训练长序列(>512个token)的Transformer模型
  • 运行长上下文(>2K个token)的推理任务
  • GPU内存受限(标准注意力机制出现OOM错误)
  • 需要在不损失精度的前提下实现2-4倍的速度提升
  • 使用PyTorch 2.2+版本或可安装flash-attn库
推荐使用替代方案的场景:
  • 标准注意力机制: 序列长度<256个token(Flash Attention的开销得不偿失)
  • xFormers: 需要更多注意力变体(而非仅追求速度)
  • 内存高效注意力: CPU推理场景(Flash Attention依赖GPU)

Common issues

常见问题

Issue: ImportError: cannot import flash_attn
Install with no-build-isolation flag:
bash
pip install flash-attn --no-build-isolation
Or install CUDA toolkit first:
bash
conda install cuda -c nvidia
pip install flash-attn --no-build-isolation
Issue: Slower than expected (no speedup)
Flash Attention benefits increase with sequence length:
  • <512 tokens: Minimal speedup (10-20%)
  • 512-2K tokens: 2-3x speedup
  • 2K tokens: 3-4x speedup
Check sequence length is sufficient.
Issue: RuntimeError: CUDA error
Verify GPU supports Flash Attention:
python
import torch
print(torch.cuda.get_device_capability())
问题:ImportError: cannot import flash_attn
使用no-build-isolation参数重新安装:
bash
pip install flash-attn --no-build-isolation
或先安装CUDA工具包:
bash
conda install cuda -c nvidia
pip install flash-attn --no-build-isolation
问题:速度未达预期(无明显提升)
Flash Attention的性能收益随序列长度增加而提升:
  • <512个token: 提升幅度极小(10-20%)
  • 512-2K个token: 2-3倍速度提升
  • 2K个token: 3-4倍速度提升
请检查序列长度是否足够。
问题:RuntimeError: CUDA error
验证GPU是否支持Flash Attention:
python
import torch
print(torch.cuda.get_device_capability())

Should be ≥(7, 5) for Turing+

Should be ≥(7, 5) for Turing+


Flash Attention requires:
- Ampere (A100, A10): ✅ Full support
- Turing (T4): ✅ Supported
- Volta (V100): ❌ Not supported

**Issue: Accuracy degradation**

Check dtype is float16 or bfloat16 (not float32):
```python
q = q.to(torch.float16)  # Or torch.bfloat16
Flash Attention uses float16/bfloat16 for speed. Float32 not supported.

Flash Attention支持以下GPU:
- Ampere架构(A100、A10、A30): ✅ 完全支持
- Turing架构(T4): ✅ 支持
- Volta架构(V100): ❌ 不支持

**问题:精度下降**

检查数据类型是否为float16或bfloat16(不支持float32):
```python
q = q.to(torch.float16)  # Or torch.bfloat16
Flash Attention为追求速度使用float16/bfloat16数据类型,不支持float32。

Advanced topics

高级主题

Integration with HuggingFace Transformers: See references/transformers-integration.md for enabling Flash Attention in BERT, GPT, Llama models.
Performance benchmarks: See references/benchmarks.md for detailed speed and memory comparisons across GPUs and sequence lengths.
Algorithm details: See references/algorithm.md for tiling strategy, recomputation, and IO complexity analysis.
Advanced features: See references/advanced-features.md for rotary embeddings, ALiBi, paged KV cache, and custom attention masks.
与HuggingFace Transformers集成: 参考references/transformers-integration.md了解如何在BERT、GPT、Llama模型中启用Flash Attention。
性能基准测试: 参考references/benchmarks.md查看不同GPU和序列长度下的详细速度与内存对比数据。
算法细节: 参考references/algorithm.md了解分块策略、重计算技术以及IO复杂度分析。
高级功能: 参考references/advanced-features.md了解旋转位置编码、ALiBi、分页KV缓存以及自定义注意力掩码等功能。

Hardware requirements

硬件要求

  • GPU: NVIDIA Ampere+ (A100, A10, A30) or AMD MI200+
  • VRAM: Same as standard attention (Flash Attention doesn't increase memory)
  • CUDA: 12.0+ (11.8 minimum)
  • PyTorch: 2.2+ for native support
Not supported: V100 (Volta), CPU inference
  • GPU: NVIDIA Ampere+(A100、A10、A30)或AMD MI200+
  • VRAM: 与标准注意力机制要求相同(Flash Attention不会增加内存占用)
  • CUDA: 12.0+(最低要求11.8)
  • PyTorch: 2.2+(原生支持)
不支持: V100(Volta架构)、CPU推理

Resources

参考资源