2026年4月19日

FlashAttention 与投机解码

标准自注意力计算:

知识库大模型推理与系统

FlashAttention 与投机解码

1. FlashAttention 概述

问题:标准注意力的内存瓶颈

标准自注意力计算: Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V

需要计算并存储完整的 N×NN \times N 注意力矩阵(NN 为序列长度),这对 GPU 内存层次结构极不友好。

GPU 内存层次

层级 类型 容量 带宽
HBM (高带宽内存) 主显存 40-80 GB 1-2 TB/s
SRAM (片上共享内存) 每个 SM 的缓存 ~192 KB/SM ~19 TB/s

关键洞察:标准注意力的瓶颈不是计算(FLOPs),而是内存 I/O —— 反复在 HBM 中读写巨大的 N×NN \times N 注意力矩阵。

QUESTION 面试题:FlashAttention 解决的核心问题是什么? 标准注意力需要在 HBM 中分配并读写 O(N2)O(N^2) 大小的中间矩阵(S = QK^T 和 P = softmax(S)),这在长序列下产生巨大的内存 I/O 开销。FlashAttention 的核心是避免在 HBM 中存储完整注意力矩阵,通过分块计算(Tiling)将中间结果保留在高速 SRAM 中,从而将 HBM 访问量从 O(N2)O(N^2) 降低到 O(N2d2/M)O(N^2d^2/M)(其中 MM 为 SRAM 大小)。

2. FlashAttention 原理

核心思想:分块计算(Tiling)

将 Q、K、V 矩阵切分成能放入 SRAM 的小块(Tile),在片上完成注意力计算,避免在 HBM 中存储完整注意力矩阵。

算法步骤

1. 将 Q, K, V 切分为块 Q_i, K_j, V_j(大小适配 SRAM)
2. For each Q_i:
     For each K_j, V_j:
       a. 在 SRAM 中计算 S_ij = Q_i × K_j^T / √d
       b. 使用 Online Softmax 累积:
          - 获取当前块的局部 max 和 sum
          - 更新全局的 running max 和 running sum
          - 修正之前的累积结果
       c. 计算 O_ij = softmax(S_ij) × V_j
       d. 累积到输出 O_i
3. 将最终输出 O 写回 HBM

Online Softmax Trick

传统 Softmax 需要两遍扫描(先求 max,再计算),FlashAttention 使用增量式 Softmaxmnew=max(mold,mlocal)m_{\text{new}} = \max(m_{\text{old}}, m_{\text{local}}) lnew=emoldmnewlold+emlocalmnewllocall_{\text{new}} = e^{m_{\text{old}} - m_{\text{new}}} \cdot l_{\text{old}} + e^{m_{\text{local}} - m_{\text{new}}} \cdot l_{\text{local}} Onew=loldemoldmnewOold+llocalemlocalmnewOlocallnewO_{\text{new}} = \frac{l_{\text{old}} \cdot e^{m_{\text{old}} - m_{\text{new}}} \cdot O_{\text{old}} + l_{\text{local}} \cdot e^{m_{\text{local}} - m_{\text{new}}} \cdot O_{\text{local}}}{l_{\text{new}}}

IO 复杂度对比

算法 HBM 读取 HBM 写入
标准注意力 O(Nd+N2)O(Nd + N^2) O(N2+Nd)O(N^2 + Nd)
FlashAttention O(N2d2/M)O(N^2d^2/M) O(Nd)O(Nd)

其中 MM 为 SRAM 大小。实际效果:2-4× 加速10-20× 内存节省

3. FlashAttention-2

改进点

  1. 更好的工作划分:减少非矩阵乘法操作(GPU 上 matmul 比 non-matmul 快约 16×)
  2. 序列长度维度并行化:v1 只在 batch 和 head 维度并行,v2 增加了序列维度
  3. 优化的线程/warp 分配:减少 warp 间的同步等待

性能

  • 在 A100 上达到理论峰值 TFLOPs 的 50-73%
  • 比标准注意力快 2-4×
  • 已集成到 PyTorch F.scaled_dot_product_attention

4. FlashAttention-3

面向 Hopper GPU (H100) 的优化

  1. 异步执行:利用 H100 的 TMA(Tensor Memory Accelerator)异步加载数据
  2. Warp 特化:部分 warp 专门负责加载(Producer),部分专门负责计算(Consumer)
  3. FP8 支持:利用 FP8 Tensor Core 获得更高吞吐
  4. Softmax 与 GEMM 重叠:计算与内存操作流水线化

性能

  • 在 H100 上达到理论峰值的 75%+
  • 比 FlashAttention-2 在 H100 上快 1.5-2×

QUESTION 面试题:FlashAttention 各版本的核心演进是什么?

版本 核心改进 性能
FlashAttention-1 分块计算 + Online Softmax 2-4× 加速
FlashAttention-2 序列维度并行 + 减少 non-matmul A100 峰值 50-73%
FlashAttention-3 Hopper TMA + FP8 + Producer-Consumer H100 峰值 75%+

核心演进路线:减少 HBM I/O → 更好的并行 → 硬件特化优化。每一代都在提高 GPU 计算单元的利用率。

5. 推理加速技术全景

LLM 推理加速技术可从计算优化访存优化系统优化三个维度分类:

计算优化

技术 原理 加速效果 适用阶段
FlashAttention 分块计算减少 HBM I/O 2-4× Prefill + Decode
算子融合 (Kernel Fusion) 多个 CUDA 算子合并为一个 1.2-2× Prefill + Decode
投机解码 Draft Model 并行验证多 Token 1.5-3× Decode
Tensor Core 优化 利用 GPU 专用矩阵单元 2-8× Prefill

访存优化

技术 原理 效果 适用阶段
KV Cache 缓存历史 KV 避免重复计算 数十倍加速 Decode
GQA / MQA 减少 KV 头数 KV Cache 减少 4-32× Decode
KV Cache 量化 FP16→FP8/INT4 显存减半至 1/4 Decode
PagedAttention 分页管理消除碎片 利用率 20-40%→>95% 全阶段

系统优化

技术 原理 效果 适用场景
Continuous Batching 迭代级调度 吞吐 2-4× 在线服务
Chunked Prefill 长 Prompt 分块处理 减少阻塞 长上下文
Prefix Caching 共享前缀 KV Cache 减少重复计算 System Prompt
CPU Offloading KV Cache 换出到 CPU 支持更大并发 显存不足

QUESTION 面试题:LLM 推理有哪些加速方法?请分类说明。 推理加速可从三个维度分类:

  1. 计算优化:FlashAttention(分块计算)、算子融合、投机解码、Tensor Core 利用
  2. 访存优化:KV Cache、GQA/MQA(减少 KV 头数)、KV Cache 量化、PagedAttention
  3. 系统优化:Continuous Batching、Chunked Prefill、Prefix Caching、CPU Offloading、量化推理

选择的依据:Prefill 瓶颈在计算(选计算优化),Decode 瓶颈在访存(选访存优化),在线服务需高吞吐(选系统优化)。

6. 投机解码(Speculative Decoding)

核心思想

自回归生成的瓶颈是逐 Token 串行生成,每次只产生一个 Token。投机解码通过引入一个小型草稿模型(Draft Model) 来并行验证多个候选 Token,加速生成。

工作流程

Step 1: Draft 阶段
┌─────────────────────────────────────────┐
│ Draft Model (小型,快速)                    │
│ 自回归生成 K 个候选 Token: [t1, t2, t3, t4, t5] │
└─────────────────────────────────────────┘
                    ↓
Step 2: Verify 阶段
┌─────────────────────────────────────────┐
│ Target Model (大型,目标模型)                │
│ 单次前向传播处理所有 K 个 Token               │
│ 获取每个位置的概率分布                        │
│ 与 Draft Token 逐一比较                     │
└─────────────────────────────────────────┘
                    ↓
Step 3: 接受/拒绝
┌─────────────────────────────────────────┐
│ t1: ✓ 接受 (概率匹配)                      │
│ t2: ✓ 接受                                │
│ t3: ✓ 接受                                │
│ t4: ✗ 拒绝 → 使用 Target 的分布采样         │
│ t5: 跳过 (因为 t4 已拒绝)                   │
│ 结果: 接受了 3 个 Token + 1 个新 Token = 4  │
└─────────────────────────────────────────┘

接受概率

Token 被接受的概率: P(accept)=min(1,ptarget(t)qdraft(t))P(\text{accept}) = \min\left(1, \frac{p_{\text{target}}(t)}{q_{\text{draft}}(t)}\right)

当 Draft 模型质量好时(qdraftptargetq_{\text{draft}} \approx p_{\text{target}}),大多数 Token 都会被接受。

数学保证

投机解码的输出分布与标准自回归解码完全相同。接受-拒绝机制基于修改的拒绝采样,确保零质量损失

加速效果

草稿 Token 数 K 接受率 实际加速比
5 80% ~2.5×
5 60% ~1.8×
10 80% ~3.5×
10 60% ~2.2×

典型实践:1.5×-3× 加速(单批次推理)

QUESTION 面试题:投机解码为什么能保证输出质量不变? 投机解码使用修改的拒绝采样(Modified Rejection Sampling):对于 Draft 模型生成的每个 Token,以 min(1,ptarget/qdraft)\min(1, p_{\text{target}}/q_{\text{draft}}) 的概率接受。如果被拒绝,则从 ptargetp_{\text{target}} 的残差分布中重新采样。数学上可以证明,这样产生的 Token 的分布恰好等于 ptargetp_{\text{target}},即与标准自回归解码完全一致。因此投机解码是零质量损失的加速方法。

7. 投机解码的变体

变体 方法 特点
Self-Speculative 同一模型少层推理作为 Draft 无需额外模型
Medusa 多头预测 无需 Draft 模型
Eagle 特征级 Draft 更高接受率
Lookahead Decoding N-gram 缓存 利用历史信息
Staged Speculative 多级 Draft 模型 层次化加速
Retrieval-based 检索候选 Token 非神经方法

Medusa 多头预测

在 Target Model 之上添加多个预测头,每个头独立预测未来 Token:

Head 0: 预测 next token (t+1)
Head 1: 预测 t+2
Head 2: 预测 t+3
...

无需额外 Draft 模型,但需要额外的训练。

Self-Speculative Decoding

使用同一模型的前 N 层作为 Draft Model,后 N 层作为 Verifier:

Layer 1-16:  快速 Draft(Early Exit)
Layer 17-32: 完整 Verify

无需维护两个模型,但需要模型支持 Early Exit。

8. 推理加速的数学分析

Decode 阶段的瓶颈分析

每个 Decode Step 的耗时主要取决于从 HBM 读取数据的延迟:

tdecodeModel_Params×dtype_size+KV_Cache_SizeHBM_Bandwidtht_{\text{decode}} \approx \frac{\text{Model\_Params} \times \text{dtype\_size} + \text{KV\_Cache\_Size}}{\text{HBM\_Bandwidth}}

以 LLaMA-3-8B (FP16) 为例:

  • 模型权重:8×109×2=168 \times 10^9 \times 2 = 16 GB
  • HBM 带宽(A100):2 TB/s
  • 理论最短 Decode 时间:16 GB2 TB/s=8\frac{16 \text{ GB}}{2 \text{ TB/s}} = 8 ms/Token
  • 实际约 15-20 ms/Token(考虑 KV Cache 读取和其他开销)

量化对 Decode 的加速原理

量化不减少 FLOPs(反量化后用相同精度计算),但减少了需要从 HBM 读取的数据量:

tdecode_INT4Model_Params×0.5+KV_Cache_SizeHBM_Bandwidtht_{\text{decode\_INT4}} \approx \frac{\text{Model\_Params} \times 0.5 + \text{KV\_Cache\_Size}}{\text{HBM\_Bandwidth}}

INT4 量化后模型权重只需读取 2 GB(原 16 GB),Decode 速度可提升 2-4×。

QUESTION 面试题:为什么 4-bit 量化后推理速度可能比 FP16 还快? LLM Decode 阶段是访存密集型(Memory-bound),瓶颈不是计算 FLOPs,而是从 HBM 读取模型权重的带宽。INT4 量化将模型权重从 16 GB 缩减到 2 GB,HBM 读取量减少到 1/8。虽然需要额外的反量化计算,但这部分计算在 SRAM 中完成,远快于 HBM I/O 节省的时间。因此,4-bit 推理在 Decode 阶段可以比 FP16 快 2-4×。

9. 实现支持

引擎 投机解码 Flash Attention 量化推理
vLLM 支持(多种 Draft 策略) 内置 支持
TGI 支持 内置 支持
TensorRT-LLM 支持 内置 支持
SGLang 支持 内置 支持
PyTorch SDPA N/A 内置 FlashAttention-2 N/A

10. FlashAttention + Speculative Decoding 组合

两者互补:

  • FlashAttention:优化单次注意力计算(降低延迟)
  • 投机解码:减少解码 Step 数量(提高吞吐)
  • 组合使用可获得 3-8× 的推理加速