FlashAttention 与投机解码
1. FlashAttention 概述
问题:标准注意力的内存瓶颈
标准自注意力计算:
需要计算并存储完整的 注意力矩阵( 为序列长度),这对 GPU 内存层次结构极不友好。
GPU 内存层次
| 层级 | 类型 | 容量 | 带宽 |
|---|---|---|---|
| HBM (高带宽内存) | 主显存 | 40-80 GB | 1-2 TB/s |
| SRAM (片上共享内存) | 每个 SM 的缓存 | ~192 KB/SM | ~19 TB/s |
关键洞察:标准注意力的瓶颈不是计算(FLOPs),而是内存 I/O —— 反复在 HBM 中读写巨大的 注意力矩阵。
QUESTION 面试题:FlashAttention 解决的核心问题是什么? 标准注意力需要在 HBM 中分配并读写 大小的中间矩阵(S = QK^T 和 P = softmax(S)),这在长序列下产生巨大的内存 I/O 开销。FlashAttention 的核心是避免在 HBM 中存储完整注意力矩阵,通过分块计算(Tiling)将中间结果保留在高速 SRAM 中,从而将 HBM 访问量从 降低到 (其中 为 SRAM 大小)。
2. FlashAttention 原理
核心思想:分块计算(Tiling)
将 Q、K、V 矩阵切分成能放入 SRAM 的小块(Tile),在片上完成注意力计算,避免在 HBM 中存储完整注意力矩阵。
算法步骤
1. 将 Q, K, V 切分为块 Q_i, K_j, V_j(大小适配 SRAM)
2. For each Q_i:
For each K_j, V_j:
a. 在 SRAM 中计算 S_ij = Q_i × K_j^T / √d
b. 使用 Online Softmax 累积:
- 获取当前块的局部 max 和 sum
- 更新全局的 running max 和 running sum
- 修正之前的累积结果
c. 计算 O_ij = softmax(S_ij) × V_j
d. 累积到输出 O_i
3. 将最终输出 O 写回 HBM
Online Softmax Trick
传统 Softmax 需要两遍扫描(先求 max,再计算),FlashAttention 使用增量式 Softmax:
IO 复杂度对比
| 算法 | HBM 读取 | HBM 写入 |
|---|---|---|
| 标准注意力 | ||
| FlashAttention |
其中 为 SRAM 大小。实际效果:2-4× 加速,10-20× 内存节省。
3. FlashAttention-2
改进点
- 更好的工作划分:减少非矩阵乘法操作(GPU 上 matmul 比 non-matmul 快约 16×)
- 序列长度维度并行化:v1 只在 batch 和 head 维度并行,v2 增加了序列维度
- 优化的线程/warp 分配:减少 warp 间的同步等待
性能
- 在 A100 上达到理论峰值 TFLOPs 的 50-73%
- 比标准注意力快 2-4×
- 已集成到 PyTorch
F.scaled_dot_product_attention
4. FlashAttention-3
面向 Hopper GPU (H100) 的优化
- 异步执行:利用 H100 的 TMA(Tensor Memory Accelerator)异步加载数据
- Warp 特化:部分 warp 专门负责加载(Producer),部分专门负责计算(Consumer)
- FP8 支持:利用 FP8 Tensor Core 获得更高吞吐
- Softmax 与 GEMM 重叠:计算与内存操作流水线化
性能
- 在 H100 上达到理论峰值的 75%+
- 比 FlashAttention-2 在 H100 上快 1.5-2×
QUESTION 面试题:FlashAttention 各版本的核心演进是什么?
版本 核心改进 性能 FlashAttention-1 分块计算 + Online Softmax 2-4× 加速 FlashAttention-2 序列维度并行 + 减少 non-matmul A100 峰值 50-73% FlashAttention-3 Hopper TMA + FP8 + Producer-Consumer H100 峰值 75%+ 核心演进路线:减少 HBM I/O → 更好的并行 → 硬件特化优化。每一代都在提高 GPU 计算单元的利用率。
5. 推理加速技术全景
LLM 推理加速技术可从计算优化、访存优化、系统优化三个维度分类:
计算优化
| 技术 | 原理 | 加速效果 | 适用阶段 |
|---|---|---|---|
| FlashAttention | 分块计算减少 HBM I/O | 2-4× | Prefill + Decode |
| 算子融合 (Kernel Fusion) | 多个 CUDA 算子合并为一个 | 1.2-2× | Prefill + Decode |
| 投机解码 | Draft Model 并行验证多 Token | 1.5-3× | Decode |
| Tensor Core 优化 | 利用 GPU 专用矩阵单元 | 2-8× | Prefill |
访存优化
| 技术 | 原理 | 效果 | 适用阶段 |
|---|---|---|---|
| KV Cache | 缓存历史 KV 避免重复计算 | 数十倍加速 | Decode |
| GQA / MQA | 减少 KV 头数 | KV Cache 减少 4-32× | Decode |
| KV Cache 量化 | FP16→FP8/INT4 | 显存减半至 1/4 | Decode |
| PagedAttention | 分页管理消除碎片 | 利用率 20-40%→>95% | 全阶段 |
系统优化
| 技术 | 原理 | 效果 | 适用场景 |
|---|---|---|---|
| Continuous Batching | 迭代级调度 | 吞吐 2-4× | 在线服务 |
| Chunked Prefill | 长 Prompt 分块处理 | 减少阻塞 | 长上下文 |
| Prefix Caching | 共享前缀 KV Cache | 减少重复计算 | System Prompt |
| CPU Offloading | KV Cache 换出到 CPU | 支持更大并发 | 显存不足 |
QUESTION 面试题:LLM 推理有哪些加速方法?请分类说明。 推理加速可从三个维度分类:
- 计算优化:FlashAttention(分块计算)、算子融合、投机解码、Tensor Core 利用
- 访存优化:KV Cache、GQA/MQA(减少 KV 头数)、KV Cache 量化、PagedAttention
- 系统优化:Continuous Batching、Chunked Prefill、Prefix Caching、CPU Offloading、量化推理
选择的依据:Prefill 瓶颈在计算(选计算优化),Decode 瓶颈在访存(选访存优化),在线服务需高吞吐(选系统优化)。
6. 投机解码(Speculative Decoding)
核心思想
自回归生成的瓶颈是逐 Token 串行生成,每次只产生一个 Token。投机解码通过引入一个小型草稿模型(Draft Model) 来并行验证多个候选 Token,加速生成。
工作流程
Step 1: Draft 阶段
┌─────────────────────────────────────────┐
│ Draft Model (小型,快速) │
│ 自回归生成 K 个候选 Token: [t1, t2, t3, t4, t5] │
└─────────────────────────────────────────┘
↓
Step 2: Verify 阶段
┌─────────────────────────────────────────┐
│ Target Model (大型,目标模型) │
│ 单次前向传播处理所有 K 个 Token │
│ 获取每个位置的概率分布 │
│ 与 Draft Token 逐一比较 │
└─────────────────────────────────────────┘
↓
Step 3: 接受/拒绝
┌─────────────────────────────────────────┐
│ t1: ✓ 接受 (概率匹配) │
│ t2: ✓ 接受 │
│ t3: ✓ 接受 │
│ t4: ✗ 拒绝 → 使用 Target 的分布采样 │
│ t5: 跳过 (因为 t4 已拒绝) │
│ 结果: 接受了 3 个 Token + 1 个新 Token = 4 │
└─────────────────────────────────────────┘
接受概率
Token 被接受的概率:
当 Draft 模型质量好时(),大多数 Token 都会被接受。
数学保证
投机解码的输出分布与标准自回归解码完全相同。接受-拒绝机制基于修改的拒绝采样,确保零质量损失。
加速效果
| 草稿 Token 数 K | 接受率 | 实际加速比 |
|---|---|---|
| 5 | 80% | ~2.5× |
| 5 | 60% | ~1.8× |
| 10 | 80% | ~3.5× |
| 10 | 60% | ~2.2× |
典型实践:1.5×-3× 加速(单批次推理)
QUESTION 面试题:投机解码为什么能保证输出质量不变? 投机解码使用修改的拒绝采样(Modified Rejection Sampling):对于 Draft 模型生成的每个 Token,以 的概率接受。如果被拒绝,则从 的残差分布中重新采样。数学上可以证明,这样产生的 Token 的分布恰好等于 ,即与标准自回归解码完全一致。因此投机解码是零质量损失的加速方法。
7. 投机解码的变体
| 变体 | 方法 | 特点 |
|---|---|---|
| Self-Speculative | 同一模型少层推理作为 Draft | 无需额外模型 |
| Medusa | 多头预测 | 无需 Draft 模型 |
| Eagle | 特征级 Draft | 更高接受率 |
| Lookahead Decoding | N-gram 缓存 | 利用历史信息 |
| Staged Speculative | 多级 Draft 模型 | 层次化加速 |
| Retrieval-based | 检索候选 Token | 非神经方法 |
Medusa 多头预测
在 Target Model 之上添加多个预测头,每个头独立预测未来 Token:
Head 0: 预测 next token (t+1)
Head 1: 预测 t+2
Head 2: 预测 t+3
...
无需额外 Draft 模型,但需要额外的训练。
Self-Speculative Decoding
使用同一模型的前 N 层作为 Draft Model,后 N 层作为 Verifier:
Layer 1-16: 快速 Draft(Early Exit)
Layer 17-32: 完整 Verify
无需维护两个模型,但需要模型支持 Early Exit。
8. 推理加速的数学分析
Decode 阶段的瓶颈分析
每个 Decode Step 的耗时主要取决于从 HBM 读取数据的延迟:
以 LLaMA-3-8B (FP16) 为例:
- 模型权重: GB
- HBM 带宽(A100):2 TB/s
- 理论最短 Decode 时间: ms/Token
- 实际约 15-20 ms/Token(考虑 KV Cache 读取和其他开销)
量化对 Decode 的加速原理
量化不减少 FLOPs(反量化后用相同精度计算),但减少了需要从 HBM 读取的数据量:
INT4 量化后模型权重只需读取 2 GB(原 16 GB),Decode 速度可提升 2-4×。
QUESTION 面试题:为什么 4-bit 量化后推理速度可能比 FP16 还快? LLM Decode 阶段是访存密集型(Memory-bound),瓶颈不是计算 FLOPs,而是从 HBM 读取模型权重的带宽。INT4 量化将模型权重从 16 GB 缩减到 2 GB,HBM 读取量减少到 1/8。虽然需要额外的反量化计算,但这部分计算在 SRAM 中完成,远快于 HBM I/O 节省的时间。因此,4-bit 推理在 Decode 阶段可以比 FP16 快 2-4×。
9. 实现支持
| 引擎 | 投机解码 | Flash Attention | 量化推理 |
|---|---|---|---|
| vLLM | 支持(多种 Draft 策略) | 内置 | 支持 |
| TGI | 支持 | 内置 | 支持 |
| TensorRT-LLM | 支持 | 内置 | 支持 |
| SGLang | 支持 | 内置 | 支持 |
| PyTorch SDPA | N/A | 内置 FlashAttention-2 | N/A |
10. FlashAttention + Speculative Decoding 组合
两者互补:
- FlashAttention:优化单次注意力计算(降低延迟)
- 投机解码:减少解码 Step 数量(提高吞吐)
- 组合使用可获得 3-8× 的推理加速