MLA Cost Analysis & Regime Guide

Regime Selection

| Regime      | s range | Best kernel          | Why                                                         |
|-------------|---------|----------------------|-------------------------------------------------------------|
| Decode      | s=1     | FlashMLA             | 16x latency reduction vs FlashAttention (compressed KV)     |
| Speculative | s=2-32  | MLAvar6+ or FlashMLA | MLAvar6+ should be able to beat FlashMLA and FlashAttention |
| Prefill     | s>128   | FlashAttention       | Avoids 4x FLOP penalty of latent-space compute              |

Crossover point: FlashAttention becomes faster than FlashMLA at approximately s=16-32 for DeepSeek-V3 parameters.
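The crossover can be sketched with a simple roofline time model, time ≈ max(FLOPs / peak_compute, bytes / peak_bandwidth), using the cost models below. The device ceilings here are rough RTX 5090-class assumptions, not measured values, so treat the exact crossover as illustrative:

```python
# Roofline time estimate: t = max(FLOPs / peak_flops, bytes / peak_bw).
# Device ceilings are assumed RTX 5090-class numbers, not measurements.
PEAK_FLOPS = 210e12   # bf16 FLOP/s (assumed)
PEAK_BW = 1.79e12     # bytes/s (assumed)
b, h, t, d, p, k, w = 32, 128, 4096, 128, 64, 512, 2  # w = 2 bytes (bfloat16)

def t_flash_attention(s):
    flops = 2 * b * h * s * t * (2 * d + p)
    nbytes = w * b * h * (s + t) * (2 * d + p)
    return max(flops / PEAK_FLOPS, nbytes / PEAK_BW)

def t_flash_mla(s):
    flops = 2 * b * h * s * t * (2 * k + p)
    nbytes = w * (b * h * s * (2 * k + p) + b * t * (k + p))
    return max(flops / PEAK_FLOPS, nbytes / PEAK_BW)

# Smallest s at which FlashAttention overtakes FlashMLA under this model.
crossover = next(s for s in range(1, 256) if t_flash_attention(s) < t_flash_mla(s))
```

With these assumed ceilings, FlashMLA wins at s=1 (memory-bound FlashAttention must stream the full decompressed KV) and FlashAttention wins at large s, with the modeled crossover landing in the same few-dozen-token region the table describes.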

Cost Models (DeepSeek-V3-like: h=128, d=128, k=512, p=64)

FlashAttention

  • FLOPs: 2bhst(2d + p) = 2bhst * 320
  • Bytes: w * bh(s+t)(2d + p) = w * bh(s+t) * 320
  • At s=1: AI ≈ 1 FLOP/byte (deeply memory-bound)
  • At s=1024: AI ≈ 819 FLOP/byte (deeply compute-bound)
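A quick numeric check of the two arithmetic-intensity endpoints (w = 2 bytes for bfloat16; note the (2d + p) factor cancels, so AI reduces to s·t / (s + t)):

```python
# Arithmetic intensity AI = FLOPs / bytes for the FlashAttention cost model.
b, h, t, d, p, w = 64, 128, 4096, 128, 64, 2  # w = 2 bytes (bfloat16)

def ai_flash_attention(s):
    flops = 2 * b * h * s * t * (2 * d + p)
    nbytes = w * b * h * (s + t) * (2 * d + p)
    return flops / nbytes  # simplifies to s*t / (s + t) when w = 2

ai_decode = ai_flash_attention(1)      # ≈ 1 FLOP/byte, memory-bound
ai_prefill = ai_flash_attention(1024)  # ≈ 819 FLOP/byte, compute-bound
```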

FlashMLA (latent-space attention via absorption)

  • FLOPs: 2bhst(2k + p) = 2bhst * 1088
  • Bytes: w * (bhs(2k+p) + bt(k+p))
  • At s=1: AI ≈ 228 FLOP/byte (compute-bound even at decode)
  • Problem: FLOPs grow linearly with s at k=512 instead of d=128, i.e. 3.4x more FLOPs per token
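The same check for the FlashMLA model confirms both numbers: AI at decode, and the 3.4x per-token FLOP penalty, which is exactly (2k + p) / (2d + p) = 1088 / 320:

```python
# FlashMLA cost model: per-token score/value work scales with (2k + p) = 1088
# instead of FlashAttention's (2d + p) = 320.
b, h, t, d, p, k, w = 64, 128, 4096, 128, 64, 512, 2  # w = 2 bytes (bfloat16)

def ai_flash_mla(s):
    flops = 2 * b * h * s * t * (2 * k + p)
    nbytes = w * (b * h * s * (2 * k + p) + b * t * (k + p))
    return flops / nbytes

ai_decode = ai_flash_mla(1)                 # ≈ 228 FLOP/byte at s=1
flop_penalty = (2 * k + p) / (2 * d + p)    # 1088 / 320 = 3.4x vs FlashAttention
```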

MLAvar6+ (split latent + decompressed)

  • FLOPs: 2bhstp + 4bhsnd + 4bhsok
  • Bytes: w * (bhsp + bhsd + bhsk + bok + btp + bhnd + bhnd + bhsd + bhsk)
  • Key: choosing n tunes operational intensity near the roofline ridge point
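Implementing these two formulas literally (with o = t − n, and the bhnd byte term appearing twice as written) shows how n acts as the tuning knob: n=0 recovers the high, FlashMLA-like intensity, n=t collapses to the low, FlashAttention-like intensity, and intermediate n lands in between:

```python
# MLAvar6+ cost model: o = t - n old tokens stay latent, n new tokens decompressed.
b, h, t, d, p, k, w = 64, 128, 4096, 128, 64, 512, 2  # w = 2 bytes (bfloat16)

def ai_mla_var6(s, n):
    o = t - n
    flops = 2*b*h*s*t*p + 4*b*h*s*n*d + 4*b*h*s*o*k
    nbytes = w * (b*h*s*p + b*h*s*d + b*h*s*k + b*o*k + b*t*p
                  + 2*b*h*n*d              # bhnd appears twice in the byte model
                  + b*h*s*d + b*h*s*k)
    return flops / nbytes

# At s=1, intensity falls monotonically as more tokens are decompressed:
ai_latent = ai_mla_var6(1, 0)       # FlashMLA-like endpoint (high AI)
ai_mid = ai_mla_var6(1, 2048)       # tunable middle ground
ai_decomp = ai_mla_var6(1, 4096)    # FlashAttention-like endpoint (low AI)
```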

Absorption Trick (FlashMLA)

Score:  Q @ K^T = Q @ (Z @ Wk)^T = (Q @ Wk^T) @ Z^T = Qz @ Z^T
Value:  softmax(A) @ V = softmax(A) @ (Z @ Wv) = (softmax(A) @ Z) @ Wv

Operate entirely in latent space (k-dim); the output decompression O_latent @ W_kvb2^T is done by a separate kernel.
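Both identities are plain matrix-multiplication associativity and can be checked numerically on toy shapes (sizes here are arbitrary; Wk/Wv stand in for the KV decompression weights):

```python
import numpy as np

# Check the absorption identities on small random matrices.
# s queries, t cached tokens, latent dim k, head dim d (toy sizes).
rng = np.random.default_rng(0)
s, t, k, d = 3, 5, 8, 4
Q = rng.standard_normal((s, d))
Z = rng.standard_normal((t, k))     # latent KV cache
Wk = rng.standard_normal((k, d))    # key decompression weights
Wv = rng.standard_normal((k, d))    # value decompression weights

# Score side: Q @ (Z @ Wk)^T == (Q @ Wk^T) @ Z^T  -- absorb Wk into Q.
assert np.allclose(Q @ (Z @ Wk).T, (Q @ Wk.T) @ Z.T)

# Value side: P @ (Z @ Wv) == (P @ Z) @ Wv for any attention weights P,
# so the (s, k) latent output can be decompressed after attention.
P = rng.standard_normal((s, t))
assert np.allclose(P @ (Z @ Wv), (P @ Z) @ Wv)
```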

MLAvar6+ Design

Split KV cache into:
  • o old tokens: latent Z (b, o, k) — compute-heavy, low bandwidth
  • n new tokens: decompressed K,V (b, n, h, d) — bandwidth-heavy, low compute
  • t = o + n
Interpolates: n=0 → FlashMLA-like, o=0 → FlashAttention-like.
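The split is exact, not an approximation: attending to the o old tokens in latent space (via absorption) and the n new tokens in decompressed space, under one shared softmax, reproduces full decompressed attention. A minimal single-head numpy sketch, ignoring the positional (p-dim) path:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
s, o, n, k, d = 2, 6, 3, 8, 4            # toy sizes; t = o + n
t = o + n
Q = rng.standard_normal((s, d))
Z = rng.standard_normal((t, k))          # latent cache, old tokens first
Wk = rng.standard_normal((k, d))
Wv = rng.standard_normal((k, d))

# Reference: decompress the whole cache (FlashAttention-style).
K, V = Z @ Wk, Z @ Wv
ref = softmax(Q @ K.T) @ V

# Split path: old tokens stay latent (absorbed), new tokens decompressed.
Qz = Q @ Wk.T                            # absorb Wk into the query
scores = np.concatenate([Qz @ Z[:o].T,   # latent scores for o old tokens
                         Q @ K[o:].T],   # decompressed scores for n new tokens
                        axis=1)
P = softmax(scores)                      # one shared softmax over all t tokens
out = (P[:, :o] @ Z[:o]) @ Wv + P[:, o:] @ V[o:]
assert np.allclose(out, ref)
```

Setting n=0 leaves only the absorbed branch (FlashMLA-like) and o=0 only the decompressed branch (FlashAttention-like), matching the interpolation described above.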

Historical Baselines (RTX 5090, b=32, t=4096, bfloat16)

Note: These results are from RTX 5090 development. Reprofile on the current device before using them as optimization targets. Ridge point, bandwidth, and compute ceilings differ across devices; see src/mla_var3/conf/devices.json.
| Config | FlashMLA | FlashAttention | MLAvar6+ V3 (best) |
|--------|----------|----------------|--------------------|
| s=1    | 419 μs   | 6,781 μs       | 829 μs (V2)        |
| s=16   | 5,161 μs | 6,727 μs       | 4,444 μs           |

MLAvar6+ V3 is the current best for speculative decoding (s=16), beating FlashMLA by 14%.

Tensor Shapes (DeepSeek-V3 defaults)

| Symbol | Description              | Default |
|--------|--------------------------|---------|
| b      | Batch size               | 64      |
| h      | Number of heads          | 128     |
| s      | Query sequence length    | varies  |
| t      | KV context length        | 4096    |
| d      | Head dimension           | 128     |
| p      | Positional embedding dim | 64      |
| k      | Latent (compressed) dim  | 512     |
Tensor layouts:
  • Q, Qc: [B, H, S, D] or [B, H, S, K]
  • K, V: [B, T, H, D] (decompressed)
  • Z: [B, T, K] (latent KV cache)
  • QPE, KPE: [B, H, S, P] or [B, T, P]

Detailed Analysis

Full cost model derivations and roofline analysis: docs/cost-analysis.md