# MLA Cost Analysis & Regime Guide
## Regime Selection
| Regime | s Range | Best Kernel | Why |
|---|---|---|---|
| Decode | s=1 | FlashMLA | 16x latency reduction vs FlashAttention (compressed KV) |
| Speculative | s=2-32 | MLAvar6+ or FlashMLA | MLAvar6+ is expected to beat both FlashMLA and FlashAttention in this range |
| Prefill | s>128 | FlashAttention | Avoids the ~3.4x FLOP penalty of latent-space compute |
Crossover point: FlashAttention becomes faster than FlashMLA at approximately s=16-32 for DeepSeek-V3 parameters.
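The regime table above can be sketched as a small dispatch helper. The thresholds are taken from this guide's DeepSeek-V3-like analysis, and `pick_kernel` is a hypothetical name, not an existing API:

```python
# Hypothetical dispatch helper summarizing the regime table above.
# Thresholds come from the DeepSeek-V3-like analysis in this guide and
# should be re-derived after reprofiling on the target device.
def pick_kernel(s: int) -> str:
    """Choose an attention kernel from the query sequence length s."""
    if s == 1:
        return "FlashMLA"        # decode: compressed KV minimizes bytes moved
    if s <= 32:
        return "MLAvar6+"        # speculative decoding: tunable intensity
    return "FlashAttention"      # prefill: avoid the latent-space FLOP penalty
```

For example, `pick_kernel(16)` returns `"MLAvar6+"`, matching the speculative-decoding row.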
## Cost Models (DeepSeek-V3-like: h=128, d=128, k=512, p=64)
### FlashAttention
- FLOPs: `2bhst(2d + p) = 2bhst * 320`
- Bytes: `w * bh(s+t)(2d + p) = w * bh(s+t) * 320`
- At s=1: AI ≈ 1 FLOP/byte (deeply memory-bound)
- At s=1024: AI ≈ 819 FLOP/byte (deeply compute-bound)
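Plugging in the defaults confirms the two arithmetic-intensity figures; with both FLOPs and bytes carrying the same `(2d + p)` factor and `w = 2` for bfloat16, AI reduces to `s*t/(s+t)`:

```python
# FlashAttention cost model with DeepSeek-V3-like defaults (bfloat16, w = 2).
b, h, t, d, p, w = 64, 128, 4096, 128, 64, 2

def fa_flops(s: int) -> int:
    return 2 * b * h * s * t * (2 * d + p)

def fa_bytes(s: int) -> int:
    return w * b * h * (s + t) * (2 * d + p)

# AI = s*t/(s+t): b, h, and (2d+p) cancel.
print(fa_flops(1) / fa_bytes(1))        # ~1 FLOP/byte: memory-bound decode
print(fa_flops(1024) / fa_bytes(1024))  # 819.2 FLOP/byte: compute-bound prefill
```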
### FlashMLA (latent-space attention via absorption)
- FLOPs: `2bhst(2k + p) = 2bhst * 1088`
- Bytes: `w * (bhs(2k+p) + bt(k+p))`
- At s=1: AI ≈ 228 FLOP/byte (compute-bound even at decode)
- Problem: FLOPs grow linearly in s with coefficient k=512 instead of d=128, i.e. 3.4x more FLOPs per token
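The same defaults reproduce the FlashMLA decode intensity and the per-token FLOP ratio against FlashAttention:

```python
# FlashMLA cost model with DeepSeek-V3-like defaults (bfloat16, w = 2).
b, h, t, d, k, p, w = 64, 128, 4096, 128, 512, 64, 2

def mla_flops(s: int) -> int:
    return 2 * b * h * s * t * (2 * k + p)

def mla_bytes(s: int) -> int:
    return w * (b * h * s * (2 * k + p) + b * t * (k + p))

print(mla_flops(1) / mla_bytes(1))  # ~228 FLOP/byte: compute-bound at decode
print((2 * k + p) / (2 * d + p))    # 3.4: FLOP penalty vs FlashAttention
```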
### MLAvar6+ (split latent + decompressed)
- FLOPs: `2bhstp + 4bhsnd + 4bhsok`
- Bytes: `w * (bhsp + bhsd + bhsk + bok + btp + bhnd + bhnd + bhsd + bhsk)`
- Key: choosing n tunes operational intensity toward the roofline ridge point
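A sketch of the MLAvar6+ model, with the byte terms exactly as listed above (including the doubled `bhnd` term). A useful sanity check: the FLOP count at n=0 reduces to the FlashMLA formula and at o=0 to the FlashAttention formula:

```python
# MLAvar6+ cost model; o = t - n. Tuning n trades latent-space FLOPs
# against decompressed-KV bandwidth.
b, h, t, d, p, k, w = 64, 128, 4096, 128, 64, 512, 2

def var6_flops(s: int, n: int) -> int:
    o = t - n
    return 2*b*h*s*t*p + 4*b*h*s*n*d + 4*b*h*s*o*k

def var6_bytes(s: int, n: int) -> int:
    o = t - n
    return w * (b*h*s*p + b*h*s*d + b*h*s*k + b*o*k + b*t*p
                + b*h*n*d + b*h*n*d + b*h*s*d + b*h*s*k)

# Endpoints: n=0 gives 2bhst(2k+p) (FlashMLA-like FLOPs),
#            n=t gives 2bhst(2d+p) (FlashAttention-like FLOPs).
for n in (0, 256, 1024, t):
    print(n, var6_flops(16, n) / var6_bytes(16, n))
```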
## Absorption Trick (FlashMLA)
- Score: `Q @ K^T = Q @ (Z @ Wk)^T = (Q @ Wk^T) @ Z^T = Qz @ Z^T`
- Value: `softmax(A) @ V = softmax(A) @ (Z @ Wv) = (softmax(A) @ Z) @ Wv`

Operate entirely in latent space (k-dim), with output decompression (`O_latent @ W_kvb2^T`) done by a separate kernel.
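Both identities are just matrix-multiplication associativity and can be verified numerically. In this sketch `Wk` and `Wv` stand in for the per-head decompression weights (names illustrative):

```python
import numpy as np

# Single-head shapes: Q (s, d), latent cache Z (t, k),
# Wk/Wv (k, d) decompress latents into keys/values.
rng = np.random.default_rng(0)
s, t, d, k = 4, 16, 8, 32
Q  = rng.standard_normal((s, d))
Z  = rng.standard_normal((t, k))
Wk = rng.standard_normal((k, d))
Wv = rng.standard_normal((k, d))

# Score: Q @ K^T == (Q @ Wk^T) @ Z^T, so keys are never materialized.
K = Z @ Wk
assert np.allclose(Q @ K.T, (Q @ Wk.T) @ Z.T)

# Value: softmax(A) @ V == (softmax(A) @ Z) @ Wv.
A = Q @ K.T
P = np.exp(A - A.max(axis=-1, keepdims=True))
P /= P.sum(axis=-1, keepdims=True)
V = Z @ Wv
assert np.allclose(P @ V, (P @ Z) @ Wv)
```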
## MLAvar6+ Design
Split KV cache into:
- o old tokens: latent Z (b, o, k) — compute-heavy, low bandwidth
- n new tokens: decompressed K,V (b, n, h, d) — bandwidth-heavy, low compute
- t = o + n
Interpolates: n=0 → FlashMLA-like, o=0 → FlashAttention-like.
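One way the two cache segments could be combined is a flash-attention-style merge of per-segment softmax partials; a minimal numpy sketch under that assumption (illustrative only, the real kernel fuses this on-chip):

```python
import numpy as np

def partial(q, K, V):
    """Unnormalized softmax-attention partial: running max, denominator, accumulator."""
    a = q @ K.T
    m = a.max()
    e = np.exp(a - m)
    return m, e.sum(), e @ V

rng = np.random.default_rng(1)
d, o, n = 8, 12, 4
q = rng.standard_normal(d)
K = rng.standard_normal((o + n, d))
V = rng.standard_normal((o + n, d))

m1, l1, acc1 = partial(q, K[:o], V[:o])  # o old tokens (latent segment)
m2, l2, acc2 = partial(q, K[o:], V[o:])  # n new tokens (decompressed segment)

# Merge with a shared running max, then normalize once.
m = max(m1, m2)
l = l1 * np.exp(m1 - m) + l2 * np.exp(m2 - m)
out = (acc1 * np.exp(m1 - m) + acc2 * np.exp(m2 - m)) / l

# Reference: one softmax over the full t = o + n context.
a = q @ K.T
p = np.exp(a - a.max())
p /= p.sum()
assert np.allclose(out, p @ V)
```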
## Historical Baselines (RTX 5090, b=32, t=4096, bfloat16)
Note: These results are from RTX 5090 development. Reprofile on the current device before using them as optimization targets. Ridge point, bandwidth, and compute ceilings differ across devices; see `src/mla_var3/conf/devices.json`.
| Config | FlashMLA | FlashAttention | MLAvar6+ V3 (best) |
|---|---|---|---|
| s=1 | 419 μs | 6,781 μs | 829 μs (V2) |
| s=16 | 5,161 μs | 6,727 μs | 4,444 μs |
MLAvar6+ V3 is the current best for speculative decoding (s=16), beating FlashMLA by 14%.
## Tensor Shapes (DeepSeek-V3 defaults)
| Symbol | Description | Default |
|---|---|---|
| b | Batch size | 64 |
| h | Number of heads | 128 |
| s | Query sequence length | varies |
| t | KV context length | 4096 |
| d | Head dimension | 128 |
| p | Positional embedding dim | 64 |
| k | Latent (compressed) dim | 512 |
Tensor layouts:
- Q, Qc: `[B, H, S, D]` or `[B, H, S, K]`
- K, V: `[B, T, H, D]` (decompressed)
- Z: `[B, T, K]` (latent KV cache)
- QPE, KPE: `[B, H, S, P]` or `[B, T, P]`
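These defaults also quantify why the latent cache matters for decode; a sketch assuming the cache holds Z (`[B, T, K]`) plus the shared KPE (`[B, T, P]`) per token:

```python
# Per-token KV-cache footprint under the defaults above (bfloat16, 2 bytes/elem).
h, d, k, p, w = 128, 128, 512, 64, 2
full_kv = w * 2 * h * d  # decompressed K and V: 65,536 bytes per token
latent  = w * (k + p)    # latent Z plus shared KPE: 1,152 bytes per token
print(full_kv / latent)  # ~57x less cache traffic per token
```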
## Detailed Analysis
Full cost model derivations and roofline analysis: `docs/cost-analysis.md`