KV Cache 基础:原理与显存计算
1. KV Cache 是什么
问题背景
LLM 推理采用自回归生成:每次生成一个 Token,每个新 Token 都需要对之前所有 Token 做 Attention 计算。
没有 KV Cache 的情况:
生成 Token 1: 计算 Q1, K1, V1 → Attention(Q1, K1, V1)
生成 Token 2: 计算 Q2, K2, V2 + 重新计算 K1, V1 → Attention(Q2, [K1,K2], [V1,V2])
生成 Token 3: 计算 Q3, K3, V3 + 重新计算 K1,K2, V1,V2 → ...
每个步骤都要重新计算之前所有 Token 的 Key 和 Value,造成巨大浪费。
有 KV Cache 的情况:
生成 Token 1: 计算 Q1, K1, V1 → 缓存 K1, V1 → Attention
生成 Token 2: 计算 Q2, K2, V2 → 追加 K2, V2 到缓存 → Attention
生成 Token 3: 计算 Q3, K3, V3 → 追加 K3, V3 到缓存 → Attention
只需计算当前 Token 的 K、V,之前的直接从缓存读取。
QUESTION 面试题:为什么需要 KV Cache?能不能不要? 不能不要。自回归生成中,每个新 Token 都需要与之前所有 Token 做注意力计算。没有 KV Cache 时,每生成一个新 Token 都需要重新计算前面所有 Token 的 K、V 向量,计算复杂度为 ,这是严重的冗余计算。KV Cache 将已计算的 K、V 缓存复用,将每个 Decode Step 的计算从 降为 (只计算当前 Token),整体推理速度可提升数倍到数十倍。
QUESTION 面试题:KV Cache 的本质是什么? KV Cache 本质上是空间换时间的策略。它将历史 Token 的 Key 和 Value 向量缓存下来,避免重复计算。代价是占用额外显存,对于长序列和大 Batch 场景,KV Cache 可能占用数十 GB 甚至更多显存。
2. 工作原理
Prefill 阶段(首次前向传播)
处理完整的 Prompt,一次性计算所有 Token 的 K、V 并缓存:
输入: "What is the capital of France?"
→ 计算 K1...K8, V1...V8 → 全部缓存
→ 输出第一个 Token: "The"
Prefill 阶段是计算密集型(Compute-bound),因为需要并行处理所有 Prompt Token,涉及大量矩阵乘法。
Decode 阶段(逐 Token 生成)
每次只处理一个新 Token:
Step 1: 输入 "The" → 计算 K_new, V_new → 追加到缓存 → 输出 "capital"
Step 2: 输入 "capital" → 计算 K_new, V_new → 追加到缓存 → 输出 "of"
Step 3: 输入 "of" → 计算 K_new, V_new → 追加到缓存 → 输出 "France"
...
Decode 阶段是访存密集型(Memory-bound),每次只计算 1 个 Token 的 Q、K、V,但需要读取全部历史 KV Cache 和模型权重。
注意力计算公式(带 KV Cache)
QUESTION 面试题:Prefill 和 Decode 阶段分别是什么瓶颈?
- Prefill 阶段:计算密集(Compute-bound),因为需要并行处理所有 Prompt Token,瓶颈在 GPU 的计算能力(FLOPs)
- Decode 阶段:访存密集(Memory-bound),每次只处理 1 个 Token 但需要读取全部 KV Cache 和模型权重,瓶颈在 GPU 的显存带宽(HBM bandwidth)
这意味着优化两个阶段的方向不同:Prefill 优化重在并行计算效率(如 FlashAttention),Decode 优化重在减少显存访问量(如 GQA、KV Cache 量化、投机解码)。
3. 显存计算公式
基础公式
其中:
2:Key 和 Value 两份缓存n_layers:Transformer 层数n_kv_heads:KV 头数(GQA 模型中少于 Query 头数)d_head:每个头的维度seq_len:序列长度(Prompt + 生成)dtype_size:数据类型大小(FP16=2, FP8=1, FP32=4)
简化公式(以字节为单位)
QUESTION 面试题:请估算 LLaMA-2-7B 模型在 Batch=32、序列长度 4096 时的 KV Cache 显存 代入公式: bytes GB/序列。Batch=32 时总 KV Cache GB。加上模型权重约 14 GB,总需求超过 78 GB,需要多卡或量化部署。
4. 实际计算示例
示例 1:LLaMA-2-7B
| 参数 | 值 |
|---|---|
| num_layers | 32 |
| num_kv_heads | 32 |
| head_dim | 128 |
| dtype | FP16 (2 bytes) |
| seq_len | 4096 |
KV Cache = 2 × 32 × 32 × 128 × 4096 × 2
= 2 × 32 × 4096 × 4096 × 2
= 2,147,483,648 bytes
≈ 2.0 GB (per sequence)
Batch Size = 32 时:64 GB
示例 2:LLaMA-2-70B(使用 GQA)
| 参数 | 值 |
|---|---|
| num_layers | 80 |
| num_kv_heads | 8 (GQA) |
| head_dim | 128 |
| dtype | FP16 (2 bytes) |
| seq_len | 4096 |
KV Cache = 2 × 80 × 8 × 128 × 4096 × 2
= 2 × 80 × 1024 × 4096 × 2
= 1,342,177,280 bytes
≈ 1.25 GB (per sequence)
注意:虽然 70B 模型远大于 7B,但 GQA 使 KV Cache 反而更小。
示例 3:长上下文场景(128K 序列)
以 LLaMA-3-8B 为例:
| 参数 | 值 |
|---|---|
| num_layers | 32 |
| num_kv_heads | 8 (GQA) |
| head_dim | 128 |
| dtype | FP16 (2 bytes) |
| seq_len | 131072 (128K) |
KV Cache = 2 × 32 × 8 × 128 × 131072 × 2
≈ 17.2 GB (per sequence)
单个请求就需要 17 GB!这就是长上下文推理的挑战。
5. GQA / MQA 对 KV Cache 的影响
| 注意力类型 | KV 头数 | KV Cache 大小 | 代表模型 |
|---|---|---|---|
| MHA (Multi-Head) | = Query 头数 | 基准 | GPT-3, LLaMA-2-7B |
| GQA (Grouped-Query) | < Query 头数 | 减少 4-8× | LLaMA-2-70B, LLaMA-3 |
| MQA (Multi-Query) | 1 | 减少最多 | PaLM, Falcon |
GQA 的意义:在推理效率和模型质量之间取得最佳平衡。
QUESTION 面试题:MHA、GQA、MQA 的区别及对推理的影响?
- MHA:每个 Query 头都有独立的 K、V 头,KV Cache 最大,但精度最好
- GQA:多个 Query 头共享一组 K、V 头(如 LLaMA-3 中 32 个 Q 头共享 8 个 KV 头),KV Cache 减少 4×,精度损失极小
- MQA:所有 Query 头共享 1 个 K、V 头,KV Cache 最小,但精度损失较大
GQA 被认为是最佳平衡点,LLaMA-2-70B、LLaMA-3、Mistral 等主流模型均采用 GQA。GQA 不仅减少 KV Cache 显存,更重要的是减少了 Decode 阶段从 HBM 读取 KV 的数据量,提升推理速度。
6. KV Cache 优化技术
6.1 PagedAttention(vLLM)
将 KV Cache 分成固定大小的 Block(类似操作系统虚拟内存分页),按需分配:
- 消除内部碎片和外部碎片
- 支持跨序列共享前缀(Prefix Caching)
- KV Cache 利用率从 20-40% 提升到接近 100%
- 详见 vLLM
6.2 KV Cache 量化
将 KV Cache 从 FP16 量化到 FP8 或 INT4:
- FP8:显存减半,精度损失极小(推荐)
- INT4:显存降至 1/4,有一定精度损失
- KIVI(2-bit KV Cache):极致压缩,研究阶段
6.3 滑动窗口注意力
只缓存最近 W 个 Token 的 KV(如 Mistral 的 W=4096):
- 固定 KV Cache 大小,与序列长度无关
- 配合 Attention Sink 保留初始 Token
- 详见 系统优化
6.4 Token 淘汰策略
智能移除不重要的 KV 条目:
- H2O (Heavy-Hitter Oracle):保留注意力分数累积最高的 Token
- Scissorhands:保留"关键"Token,移除冗余 Token
- 基于注意力分数的淘汰,在保持精度的同时显著压缩 KV Cache
6.5 跨层 KV 共享
某些架构(如 MiniCPM、Adaptive-Layer KV)共享相邻层的 KV Cache:
- 减少总 KV 存储量
- 适用于层间 KV 差异小的模型
7. Prefill vs Decode 的 KV Cache 行为
| 阶段 | 计算特点 | 内存特点 | 瓶颈 |
|---|---|---|---|
| Prefill | 计算密集(并行处理所有 Prompt Token) | 一次性写入大量 KV | 计算 |
| Decode | 访存密集(每次处理 1 Token,读取全部 KV) | 逐步追加少量 KV | 内存带宽 |
Decode 阶段是 KV Cache 的主要瓶颈:每个 Step 都需要从 HBM 读取完整 KV Cache,内存带宽成为限制因素。
QUESTION 面试题:如何减少 KV Cache 的显存占用? 多种方法可组合使用:
- GQA/MQA:减少 KV 头数,从结构上压缩(如 LLaMA-3 用 8 KV heads 替代 32)
- KV Cache 量化:FP16→FP8 可减半显存,精度损失极小
- PagedAttention:消除碎片,利用率从 20-40% 提升到 >95%
- 滑动窗口:固定 KV Cache 为 ,但会丢失长距离信息
- Token 淘汰:H2O 等策略移除不重要的 KV 条目
- 跨层共享:相邻层共享 KV,减少总存储量
- 降低序列长度:通过 Prompt 压缩、RAG 替代长上下文等
8. KV Cache 与推理加速的关系
KV Cache 是推理加速的核心枢纽,几乎所有推理优化技术都直接或间接作用于 KV Cache:
推理加速技术图谱(KV Cache 为中心)
┌─ GQA / MQA ──── 减少 KV 头数
├─ KV Cache 量化 ── 降低单条目大小
KV Cache ───────├─ PagedAttention ─ 消除碎片,高效管理
显存/带宽优化 ├─ 滑动窗口 ────── 固定 KV 大小
├─ Token 淘汰 ──── 智能压缩
└─ Prefix Caching ─ 共享公共前缀
┌─ FlashAttention ── 高效注意力计算内核
计算加速 ───────├─ 投机解码 ──────── 减少 Decode 步数
└─ 算子融合 ──────── 减少内存读写次数
┌─ Continuous Batching ─ 最大化 GPU 利用
系统优化 ───────├─ Chunked Prefill ──── 避免长 Prefill 阻塞
└─ Tensor Parallelism ── 多 GPU 并行
QUESTION 面试题:推理速度优化主要有哪些方向? 主要分为三大方向:
- KV Cache 优化(显存/带宽):GQA、量化、PagedAttention、滑动窗口、Token 淘汰、Prefix Caching
- 计算优化(FLOPs):FlashAttention(分块计算减少 HBM 访问)、投机解码(并行验证多 Token)、算子融合
- 系统级优化(吞吐):Continuous Batching、Chunked Prefill、Tensor/Pipeline Parallelism
其中 Decode 阶段主要受限于显存带宽,因此 KV Cache 相关优化(方向 1)对 Decode 加速最为关键。