optimizing-attention-flash
# Flash Attention - Fast Memory-Efficient Attention

## Quick start
Flash Attention provides 2-4x speedup and 10-20x memory reduction for transformer attention through IO-aware tiling and recomputation.
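The tiling idea behind those numbers can be sketched in a few lines: instead of materializing the full seq×seq score matrix, process K/V in blocks and keep a running row-max and normalizer (online softmax). A minimal single-head NumPy sketch of the idea (illustration only, not the actual CUDA kernel):

```python
import numpy as np

def tiled_attention(q, k, v, block=64):
    """Single-head attention computed over K/V blocks with online softmax.

    Never materializes the full (seq_q, seq_k) score matrix: per-query
    working memory is O(block), which is the core Flash Attention idea.
    """
    seq_q, d = q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.zeros_like(q)
    m = np.full(seq_q, -np.inf)   # running row max
    l = np.zeros(seq_q)           # running softmax normalizer
    for s in range(0, k.shape[0], block):
        kb, vb = k[s:s + block], v[s:s + block]
        scores = (q @ kb.T) * scale                 # only a (seq_q, block) tile
        m_new = np.maximum(m, scores.max(axis=1))
        correction = np.exp(m - m_new)              # rescale old accumulators
        p = np.exp(scores - m_new[:, None])
        l = l * correction + p.sum(axis=1)
        out = out * correction[:, None] + p @ vb
        m = m_new
    return out / l[:, None]

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((256, 32)) for _ in range(3))
out_tiled = tiled_attention(q, k, v)

# Reference: standard full-matrix softmax attention
scores = (q @ k.T) / np.sqrt(32)
w = np.exp(scores - scores.max(axis=1, keepdims=True))
out_ref = (w / w.sum(axis=1, keepdims=True)) @ v
```

The two outputs match to floating-point tolerance because Flash Attention is exact attention, not an approximation; the real kernel adds recomputation in the backward pass to avoid storing the tiles.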
**PyTorch native (easiest, PyTorch 2.2+):**

```python
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)  # [batch, heads, seq, dim]
k = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)

# Automatically uses Flash Attention if available
out = F.scaled_dot_product_attention(q, k, v)
```
**flash-attn library (more features)**:

```bash
pip install flash-attn --no-build-isolation
```

```python
from flash_attn import flash_attn_func

# q, k, v: [batch, seqlen, nheads, headdim]
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
```

## Common workflows

### Workflow 1: Enable in existing PyTorch model
Copy this checklist:

```
Flash Attention Integration:
- [ ] Step 1: Check PyTorch version (≥2.2)
- [ ] Step 2: Enable Flash Attention backend
- [ ] Step 3: Verify speedup with profiling
- [ ] Step 4: Test accuracy matches baseline
```

**Step 1: Check PyTorch version**

```bash
python -c "import torch; print(torch.__version__)"
# Should be ≥2.2.0
```

If <2.2, upgrade:

```bash
pip install --upgrade torch
```

**Step 2: Enable Flash Attention backend**
Replace standard attention:

```python
import math
import torch.nn.functional as F

# Before (standard attention)
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d_k), dim=-1)
out = attn_weights @ v

# After (Flash Attention)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```

Force Flash Attention backend:

```python
with torch.backends.cuda.sdp_kernel(
    enable_flash=True,
    enable_math=False,
    enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(q, k, v)
```

**Step 3: Verify speedup with profiling**
```python
import torch.utils.benchmark as benchmark

def test_attention(use_flash):
    q, k, v = [torch.randn(2, 8, 2048, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
    if use_flash:
        with torch.backends.cuda.sdp_kernel(enable_flash=True):
            return F.scaled_dot_product_attention(q, k, v)
    else:
        attn = (q @ k.transpose(-2, -1) / 8.0).softmax(dim=-1)
        return attn @ v

# Benchmark
t_flash = benchmark.Timer(stmt='test_attention(True)', globals=globals())
t_standard = benchmark.Timer(stmt='test_attention(False)', globals=globals())
print(f"Flash: {t_flash.timeit(100).mean:.3f}s")
print(f"Standard: {t_standard.timeit(100).mean:.3f}s")
```

Expected: 2-4x speedup for sequences >512 tokens.

**Step 4: Test accuracy matches baseline**
```python
# Compare outputs
q, k, v = [torch.randn(1, 8, 512, 64, device='cuda', dtype=torch.float16) for _ in range(3)]

# Flash Attention
out_flash = F.scaled_dot_product_attention(q, k, v)

# Standard attention
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / 8.0, dim=-1)
out_standard = attn_weights @ v

# Check difference
diff = (out_flash - out_standard).abs().max()
print(f"Max difference: {diff:.6f}")
# Should be <1e-3 for float16
```

### Workflow 2: Use flash-attn library for advanced features
For multi-query attention, sliding window, or H100 FP8.
Copy this checklist:

```
flash-attn Library Setup:
- [ ] Step 1: Install flash-attn library
- [ ] Step 2: Modify attention code
- [ ] Step 3: Enable advanced features
- [ ] Step 4: Benchmark performance
```

**Step 1: Install flash-attn library**

```bash
# NVIDIA GPUs (CUDA 12.0+)
pip install flash-attn --no-build-isolation

# Verify installation
python -c "from flash_attn import flash_attn_func; print('Success')"
```
**Step 2: Modify attention code**

```python
from flash_attn import flash_attn_func

# Input: [batch_size, seq_len, num_heads, head_dim]
# Transpose from [batch, heads, seq, dim] if needed
q = q.transpose(1, 2)  # [batch, seq, heads, dim]
k = k.transpose(1, 2)
v = v.transpose(1, 2)

out = flash_attn_func(
    q, k, v,
    dropout_p=0.1,
    causal=True,           # For autoregressive models
    window_size=(-1, -1),  # No sliding window
    softmax_scale=None     # Auto-scale
)
out = out.transpose(1, 2)  # Back to [batch, heads, seq, dim]
```

**Step 3: Enable advanced features**
Multi-query attention (shared K/V across heads):

```python
from flash_attn import flash_attn_func

# q: [batch, seq, num_q_heads, dim]
# k, v: [batch, seq, num_kv_heads, dim]  # Fewer KV heads
out = flash_attn_func(q, k, v)  # Automatically handles MQA
```

Sliding window attention (local attention):

```python
# Only attend to window of 256 tokens before/after
out = flash_attn_func(
    q, k, v,
    window_size=(256, 256),  # (left, right) window
    causal=True
)
```

**Step 4: Benchmark performance**
```python
import torch
from flash_attn import flash_attn_func
import time

q, k, v = [torch.randn(4, 4096, 32, 64, device='cuda', dtype=torch.float16) for _ in range(3)]

# Warmup
for _ in range(10):
    _ = flash_attn_func(q, k, v)

# Benchmark
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
    out = flash_attn_func(q, k, v)
torch.cuda.synchronize()
end = time.time()

print(f"Time per iteration: {(end-start)/100*1000:.2f}ms")
print(f"Memory allocated: {torch.cuda.max_memory_allocated()/1e9:.2f}GB")
```

### Workflow 3: H100 FP8 optimization (FlashAttention-3)
For maximum performance on H100 GPUs.
Copy this checklist:

```
FP8 Setup:
- [ ] Step 1: Verify H100 GPU available
- [ ] Step 2: Install flash-attn with FP8 support
- [ ] Step 3: Convert inputs to FP8
- [ ] Step 4: Run with FP8 attention
```

**Step 1: Verify H100 GPU**

```bash
nvidia-smi --query-gpu=name --format=csv
# Should show "H100" or "H800"
```

**Step 2: Install flash-attn with FP8 support**

```bash
pip install flash-attn --no-build-isolation
# FP8 support included for H100
```

**Step 3: Convert inputs to FP8**
```python
import torch

q = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
k = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)

# Convert to float8_e4m3 (FP8)
q_fp8 = q.to(torch.float8_e4m3fn)
k_fp8 = k.to(torch.float8_e4m3fn)
v_fp8 = v.to(torch.float8_e4m3fn)
```

**Step 4: Run with FP8 attention**

```python
from flash_attn import flash_attn_func

# FlashAttention-3 automatically uses FP8 kernels on H100
out = flash_attn_func(q_fp8, k_fp8, v_fp8)
# Result: ~1.2 PFLOPS, 1.5-2x faster than FP16
```

## When to use vs alternatives
**Use Flash Attention when:**
- Training transformers with sequences >512 tokens
- Running inference with long context (>2K tokens)
- GPU memory constrained (OOM with standard attention)
- Need 2-4x speedup without accuracy loss
- Using PyTorch 2.2+ or can install flash-attn

**Use alternatives instead:**
- Standard attention: Sequences <256 tokens (overhead not worth it)
- xFormers: Need more attention variants (not just speed)
- Memory-efficient attention: CPU inference (Flash Attention needs GPU)
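These rules of thumb can be encoded in a small dispatch helper. The function name and return labels below are hypothetical, purely to make the decision table concrete:

```python
def choose_attention_impl(seq_len: int, has_cuda: bool,
                          need_exotic_variants: bool = False) -> str:
    """Hypothetical dispatcher encoding the guidance above.

    Returns one of: 'memory_efficient_cpu', 'xformers', 'standard', 'flash'.
    """
    if not has_cuda:
        return 'memory_efficient_cpu'  # Flash Attention needs a GPU
    if need_exotic_variants:
        return 'xformers'              # broader menu of attention variants
    if seq_len < 256:
        return 'standard'              # kernel overhead not worth it
    return 'flash'                     # 2-4x speedup, exact results

print(choose_attention_impl(4096, has_cuda=True))   # → flash
print(choose_attention_impl(128, has_cuda=True))    # → standard
print(choose_attention_impl(4096, has_cuda=False))  # → memory_efficient_cpu
```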
## Common issues
**Issue: ImportError: cannot import flash_attn**

Install with the `--no-build-isolation` flag:

```bash
pip install flash-attn --no-build-isolation
```

Or install the CUDA toolkit first:

```bash
conda install cuda -c nvidia
pip install flash-attn --no-build-isolation
```

**Issue: Slower than expected (no speedup)**

Flash Attention benefits increase with sequence length:
- <512 tokens: Minimal speedup (10-20%)
- 512-2K tokens: 2-3x speedup
- >2K tokens: 3-4x speedup

Check that the sequence length is sufficient.
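The length dependence follows from the quadratic attention-matrix footprint that Flash Attention avoids. A quick back-of-the-envelope helper (illustrative only, function name is ours):

```python
def attn_matrix_bytes(batch: int, heads: int, seq: int, dtype_bytes: int = 2) -> int:
    """Bytes needed to materialize the full (seq x seq) attention matrix in
    standard attention -- memory Flash Attention never allocates."""
    return batch * heads * seq * seq * dtype_bytes

# float16, batch 2, 8 heads: the score matrix alone grows quadratically
for seq in (512, 2048, 8192):
    gb = attn_matrix_bytes(2, 8, seq) / 1e9
    print(f"seq={seq}: {gb:.2f} GB")  # → 0.01, 0.13, 2.15 GB
```

Below ~512 tokens this matrix is tiny, so kernel launch overhead dominates and the speedup is modest; past 2K tokens, avoiding these reads and writes to GPU memory is where the 3-4x gains come from.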
**Issue: RuntimeError: CUDA error**

Verify the GPU supports Flash Attention:

```python
import torch
print(torch.cuda.get_device_capability())
# Should be ≥(7, 5) for Turing+
```

Flash Attention requires:
- Ampere (A100, A10): ✅ Full support
- Turing (T4): ✅ Supported
- Volta (V100): ❌ Not supported
**Issue: Accuracy degradation**

Check that the dtype is float16 or bfloat16 (not float32):

```python
q = q.to(torch.float16)  # Or torch.bfloat16
```

Flash Attention uses float16/bfloat16 for speed; float32 is not supported.
## Advanced topics
**Integration with HuggingFace Transformers:** See references/transformers-integration.md for enabling Flash Attention in BERT, GPT, Llama models.

**Performance benchmarks:** See references/benchmarks.md for detailed speed and memory comparisons across GPUs and sequence lengths.

**Algorithm details:** See references/algorithm.md for tiling strategy, recomputation, and IO complexity analysis.

**Advanced features:** See references/advanced-features.md for rotary embeddings, ALiBi, paged KV cache, and custom attention masks.
## Hardware requirements
- GPU: NVIDIA Ampere+ (A100, A10, A30) or AMD MI200+
- VRAM: Same as standard attention (Flash Attention doesn't increase memory)
- CUDA: 12.0+ (11.8 minimum)
- PyTorch: 2.2+ for native support
Not supported: V100 (Volta), CPU inference
## Resources
- Paper: "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" (NeurIPS 2022)
- Paper: "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning" (ICLR 2024)
- Blog: https://tridao.me/blog/2024/flash3/
- GitHub: https://github.com/Dao-AILab/flash-attention
- PyTorch docs: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html