2024年8月21日

Transformer 改进与注意力优化前沿

Transformer 注意力优化的核心目标是降低 $O(n^2)$ 的计算和内存复杂度,同时保持模型质量,主要从推理效率(KV Cache 压缩)和训练效率(硬件感知计算)两个维度展开。

知识库大模型基础原理llmmodeltransformerattentionoptimization

先说结论

Transformer 注意力优化的核心目标是降低 O(n2)O(n^2) 的计算和内存复杂度,同时保持模型质量,主要从推理效率(KV Cache 压缩)和训练效率(硬件感知计算)两个维度展开。

注意力机制优化

Multi-Query Attention (MQA)

  • 核心思想:所有 query head 共享一个 key-value head
  • 优势:极大减少 KV Cache 大小,推理速度快
  • 代价:精度略有下降
  • 使用模型:PaLM, Falcon, StarCoder
  • 原理:标准 MHA 有 hh 组 KV,MQA 只有 1 组。KV Cache 从 2×h×dhead2 \times h \times d_{head} 降为 2×dhead2 \times d_{head},减少 hh

QUESTION MQA 为什么能加速推理? 推理的瓶颈往往不在计算而在内存带宽。KV Cache 越小,每步从 HBM 读取的数据越少,推理越快。MQA 将 KV Cache 减少了 hh 倍(hh 为注意力头数),显著降低内存带宽压力。

Grouped Query Attention (GQA)

  • 核心思想:介于 MHA 和 MQA 之间,将 query head 分组,每组共享一个 KV head
  • 优势:在推理速度和模型质量之间取得平衡
  • 使用模型:LLaMA 2/3, Mistral, Qwen 2
  • 关键参数:group 大小(通常 KV head 数为 8)

QUESTION GQA 与 MHA、MQA 的区别?

方法 KV Head 数 KV Cache 大小 模型质量 推理速度
MHA hh(与 Q 头数相同) 最大 最好 最慢
GQA gg1<g<h1 < g < h 中等 接近 MHA 中等
MQA 1 最小 略低于 MHA 最快

GQA 可以从已训练的 MHA 模型转换而来:将 Q 头分组后,每组内取 KV 头的均值作为共享 KV 头。

Multi-head Latent Attention (MLA)

  • 核心思想:将 KV 压缩到低维潜在表示,推理时只需缓存压缩后的向量
  • 使用模型:DeepSeek-V2/V3
  • 优势:比 GQA 更极端的 KV Cache 压缩,推理效率极高
  • 原理:通过低秩投影 cKV=WDKVhtc_{KV} = W_{DKV} \cdot h_t 将 KV 压缩到低维空间 cc,推理时只缓存 cKVc_{KV} 而非完整 KV,需要时通过上投影恢复

QUESTION MLA 与 GQA 压缩 KV Cache 的方式有何不同? GQA 通过"减少头数"压缩——每组 Q 头共享一个完整的 KV 头,是离散的、结构化的压缩。MLA 通过"低秩投影"压缩——将 KV 投影到低维潜在空间,是连续的、信息保留度更高的压缩。MLA 的压缩率更高,且恢复后的信息损失更小。

FlashAttention 2/3

  • 核心思想:硬件感知的注意力实现,优化 GPU 内存访问模式
  • 效果:接近理论峰值 FLOP 利用率
  • 关键创新:将注意力计算的内存从 O(n2)O(n^2) 降低到 O(n)O(n)
  • FlashAttention-3:针对 Hopper GPU (H100) 进一步优化
  • 影响:几乎所有主流 LLM 都使用 FlashAttention

QUESTION FlashAttention 为什么快?它做近似了吗? FlashAttention 是精确计算,不是近似。加速来自:(1) 分块计算(Tiling):将 Q/K/V 分成小块,在 SRAM(片上快速缓存)中完成注意力计算,避免将 n×nn \times n 的完整矩阵写入 HBM;(2) 在线 Softmax:逐块更新 softmax 结果,无需一次性加载所有分数;(3) 减少了对 HBM 的读写次数(从 O(n2)O(n^2) 降到 O(n)O(n)),而 HBM 读写是 GPU 计算的主要瓶颈。

状态空间模型 (SSM)

Mamba

  • 核心思想:选择性状态空间模型,输入依赖的选择机制
  • 优势:线性复杂度序列建模
  • 局限:在需要精确回忆的任务上不如 Transformer

Mamba-2

  • 引入结构化状态空间对偶性 (SSD)
  • 展示 SSM 和注意力之间的联系
  • 更快的训练和推理

混合架构

Jamba (AI21 Labs, 2024)

  • Mamba 层 + Transformer 注意力层 + MoE 混合架构
  • 长上下文处理效率高
  • 代表了"混合架构"的趋势

长上下文优化

位置编码改进

方法 说明 使用模型
RoPE 旋转位置编码,捕获相对位置 LLaMA, PaLM, Qwen
ALiBi 线性偏置注意力,支持长度外推 BLOOM, MPT
YaRN RoPE 的改进外推方法 开源社区

上下文压缩

  • Sliding Window Attention:只关注局部窗口(Mistral),复杂度从 O(n2)O(n^2) 降到 O(nw)O(n \cdot w)ww 为窗口大小
  • Ring Attention:跨设备分布式注意力,将序列分块分配到不同设备,理论上支持无限长序列
  • KV Cache 剪枝:H2O、Scissorhands 等方法剪掉不重要的 KV,动态管理缓存

稀疏与高效注意力

稀疏注意力模式

  • 局部窗口 + 全局 token 的混合模式
  • Strided attention(Longformer)
  • 可学习稀疏模式

线性注意力

  • 用核函数近似替代 softmax:softmax(qk)ϕ(q)ϕ(k)softmax(q \cdot k) \approx \phi(q) \cdot \phi(k)
  • 复杂度从 O(n2d)O(n^2 d) 降到 O(nd2)O(nd^2)
  • 利用结合律先计算 ϕ(K)TV\phi(K)^T Vd×dd \times d 矩阵),再乘 QQ
  • 质量通常不如标准注意力

QUESTION 线性注意力为什么质量不如标准注意力? 核函数近似 ϕ(q)ϕ(k)\phi(q) \cdot \phi(k) 无法精确代替 softmax 的归一化和指数非线性。softmax 的"赢者通吃"效应(少量高相似度对获得大部分权重)对精确信息检索至关重要,线性近似的均匀分配倾向导致信息模糊。

模型架构趋势总结

技术 状态 优先级
GQA 已成标配 重点读
FlashAttention 2/3 已成标配 重点读
MLA (DeepSeek) 前沿 重点读
RoPE 已成标配 重点读
Mamba/SSM 活跃研究 了解即可
混合架构 趋势 了解即可
线性注意力 研究阶段 暂不深挖

延伸阅读

参考资料

  • Dao, T. et al. (2022). "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness." NeurIPS 2022.
  • Shazeer, N. (2019). "Fast Transformer Decoding: One Write-Head is All You Need." arXiv:1911.02150 (MQA)
  • Ainslie, J. et al. (2023). "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints." EMNLP 2023.
  • DeepSeek-V2 Technical Report (2024). MLA 的原始论文.