2026年4月19日

KV Cache 基础:原理与显存计算

LLM 推理采用**自回归生成**:每次生成一个 Token,每个新 Token 都需要对之前所有 Token 做 Attention 计算。

知识库大模型推理与系统

KV Cache 基础:原理与显存计算

1. KV Cache 是什么

问题背景

LLM 推理采用自回归生成:每次生成一个 Token,每个新 Token 都需要对之前所有 Token 做 Attention 计算。

没有 KV Cache 的情况

生成 Token 1: 计算 Q1, K1, V1 → Attention(Q1, K1, V1)
生成 Token 2: 计算 Q2, K2, V2 + 重新计算 K1, V1 → Attention(Q2, [K1,K2], [V1,V2])
生成 Token 3: 计算 Q3, K3, V3 + 重新计算 K1,K2, V1,V2 → ...

每个步骤都要重新计算之前所有 Token 的 Key 和 Value,造成巨大浪费。

有 KV Cache 的情况

生成 Token 1: 计算 Q1, K1, V1 → 缓存 K1, V1 → Attention
生成 Token 2: 计算 Q2, K2, V2 → 追加 K2, V2 到缓存 → Attention
生成 Token 3: 计算 Q3, K3, V3 → 追加 K3, V3 到缓存 → Attention

只需计算当前 Token 的 K、V,之前的直接从缓存读取。

QUESTION 面试题:为什么需要 KV Cache?能不能不要? 不能不要。自回归生成中,每个新 Token 都需要与之前所有 Token 做注意力计算。没有 KV Cache 时,每生成一个新 Token 都需要重新计算前面所有 Token 的 K、V 向量,计算复杂度为 O(N2)O(N^2),这是严重的冗余计算。KV Cache 将已计算的 K、V 缓存复用,将每个 Decode Step 的计算从 O(N)O(N) 降为 O(1)O(1)(只计算当前 Token),整体推理速度可提升数倍到数十倍。

QUESTION 面试题:KV Cache 的本质是什么? KV Cache 本质上是空间换时间的策略。它将历史 Token 的 Key 和 Value 向量缓存下来,避免重复计算。代价是占用额外显存,对于长序列和大 Batch 场景,KV Cache 可能占用数十 GB 甚至更多显存。

2. 工作原理

Prefill 阶段(首次前向传播)

处理完整的 Prompt,一次性计算所有 Token 的 K、V 并缓存:

输入: "What is the capital of France?"
→ 计算 K1...K8, V1...V8 → 全部缓存
→ 输出第一个 Token: "The"

Prefill 阶段是计算密集型(Compute-bound),因为需要并行处理所有 Prompt Token,涉及大量矩阵乘法。

Decode 阶段(逐 Token 生成)

每次只处理一个新 Token:

Step 1: 输入 "The" → 计算 K_new, V_new → 追加到缓存 → 输出 "capital"
Step 2: 输入 "capital" → 计算 K_new, V_new → 追加到缓存 → 输出 "of"
Step 3: 输入 "of" → 计算 K_new, V_new → 追加到缓存 → 输出 "France"
...

Decode 阶段是访存密集型(Memory-bound),每次只计算 1 个 Token 的 Q、K、V,但需要读取全部历史 KV Cache 和模型权重。

注意力计算公式(带 KV Cache)

Attention(Qt,K1:t,V1:t)=softmax(QtK1:tTdk)V1:t\text{Attention}(Q_t, K_{1:t}, V_{1:t}) = \text{softmax}\left(\frac{Q_t K_{1:t}^T}{\sqrt{d_k}}\right) V_{1:t}

QUESTION 面试题:Prefill 和 Decode 阶段分别是什么瓶颈?

  • Prefill 阶段:计算密集(Compute-bound),因为需要并行处理所有 Prompt Token,瓶颈在 GPU 的计算能力(FLOPs)
  • Decode 阶段:访存密集(Memory-bound),每次只处理 1 个 Token 但需要读取全部 KV Cache 和模型权重,瓶颈在 GPU 的显存带宽(HBM bandwidth)

这意味着优化两个阶段的方向不同:Prefill 优化重在并行计算效率(如 FlashAttention),Decode 优化重在减少显存访问量(如 GQA、KV Cache 量化、投机解码)。

3. 显存计算公式

基础公式

KV Cache Size=2×nlayers×nkv_heads×dhead×seq_len×dtype_size\text{KV Cache Size} = 2 \times n_{\text{layers}} \times n_{\text{kv\_heads}} \times d_{\text{head}} \times \text{seq\_len} \times \text{dtype\_size}

其中:

  • 2:Key 和 Value 两份缓存
  • n_layers:Transformer 层数
  • n_kv_heads:KV 头数(GQA 模型中少于 Query 头数)
  • d_head:每个头的维度
  • seq_len:序列长度(Prompt + 生成)
  • dtype_size:数据类型大小(FP16=2, FP8=1, FP32=4)

简化公式(以字节为单位)

KV Cache (bytes)=2×nlayers×hidden_dimkv×seq_len×dtype_size\text{KV Cache (bytes)} = 2 \times n_{\text{layers}} \times \text{hidden\_dim}_{kv} \times \text{seq\_len} \times \text{dtype\_size}

QUESTION 面试题:请估算 LLaMA-2-7B 模型在 Batch=32、序列长度 4096 时的 KV Cache 显存 代入公式:KV Cache=2×32×32×128×4096×2=2,147,483,648\text{KV Cache} = 2 \times 32 \times 32 \times 128 \times 4096 \times 2 = 2,147,483,648 bytes 2.0\approx 2.0 GB/序列。Batch=32 时总 KV Cache =64= 64 GB。加上模型权重约 14 GB,总需求超过 78 GB,需要多卡或量化部署。

4. 实际计算示例

示例 1:LLaMA-2-7B

参数
num_layers 32
num_kv_heads 32
head_dim 128
dtype FP16 (2 bytes)
seq_len 4096
KV Cache = 2 × 32 × 32 × 128 × 4096 × 2
         = 2 × 32 × 4096 × 4096 × 2
         = 2,147,483,648 bytes
         ≈ 2.0 GB (per sequence)

Batch Size = 32 时:64 GB

示例 2:LLaMA-2-70B(使用 GQA)

参数
num_layers 80
num_kv_heads 8 (GQA)
head_dim 128
dtype FP16 (2 bytes)
seq_len 4096
KV Cache = 2 × 80 × 8 × 128 × 4096 × 2
         = 2 × 80 × 1024 × 4096 × 2
         = 1,342,177,280 bytes
         ≈ 1.25 GB (per sequence)

注意:虽然 70B 模型远大于 7B,但 GQA 使 KV Cache 反而更小。

示例 3:长上下文场景(128K 序列)

以 LLaMA-3-8B 为例:

参数
num_layers 32
num_kv_heads 8 (GQA)
head_dim 128
dtype FP16 (2 bytes)
seq_len 131072 (128K)
KV Cache = 2 × 32 × 8 × 128 × 131072 × 2
         ≈ 17.2 GB (per sequence)

单个请求就需要 17 GB!这就是长上下文推理的挑战。

5. GQA / MQA 对 KV Cache 的影响

注意力类型 KV 头数 KV Cache 大小 代表模型
MHA (Multi-Head) = Query 头数 基准 GPT-3, LLaMA-2-7B
GQA (Grouped-Query) < Query 头数 减少 4-8× LLaMA-2-70B, LLaMA-3
MQA (Multi-Query) 1 减少最多 PaLM, Falcon

GQA 的意义:在推理效率和模型质量之间取得最佳平衡。

QUESTION 面试题:MHA、GQA、MQA 的区别及对推理的影响?

  • MHA:每个 Query 头都有独立的 K、V 头,KV Cache 最大,但精度最好
  • GQA:多个 Query 头共享一组 K、V 头(如 LLaMA-3 中 32 个 Q 头共享 8 个 KV 头),KV Cache 减少 4×,精度损失极小
  • MQA:所有 Query 头共享 1 个 K、V 头,KV Cache 最小,但精度损失较大

GQA 被认为是最佳平衡点,LLaMA-2-70B、LLaMA-3、Mistral 等主流模型均采用 GQA。GQA 不仅减少 KV Cache 显存,更重要的是减少了 Decode 阶段从 HBM 读取 KV 的数据量,提升推理速度。

6. KV Cache 优化技术

6.1 PagedAttention(vLLM)

将 KV Cache 分成固定大小的 Block(类似操作系统虚拟内存分页),按需分配:

  • 消除内部碎片和外部碎片
  • 支持跨序列共享前缀(Prefix Caching)
  • KV Cache 利用率从 20-40% 提升到接近 100%
  • 详见 vLLM

6.2 KV Cache 量化

将 KV Cache 从 FP16 量化到 FP8 或 INT4:

  • FP8:显存减半,精度损失极小(推荐)
  • INT4:显存降至 1/4,有一定精度损失
  • KIVI(2-bit KV Cache):极致压缩,研究阶段

6.3 滑动窗口注意力

只缓存最近 W 个 Token 的 KV(如 Mistral 的 W=4096):

  • 固定 KV Cache 大小,与序列长度无关
  • 配合 Attention Sink 保留初始 Token
  • 详见 系统优化

6.4 Token 淘汰策略

智能移除不重要的 KV 条目:

  • H2O (Heavy-Hitter Oracle):保留注意力分数累积最高的 Token
  • Scissorhands:保留"关键"Token,移除冗余 Token
  • 基于注意力分数的淘汰,在保持精度的同时显著压缩 KV Cache

6.5 跨层 KV 共享

某些架构(如 MiniCPM、Adaptive-Layer KV)共享相邻层的 KV Cache:

  • 减少总 KV 存储量
  • 适用于层间 KV 差异小的模型

7. Prefill vs Decode 的 KV Cache 行为

阶段 计算特点 内存特点 瓶颈
Prefill 计算密集(并行处理所有 Prompt Token) 一次性写入大量 KV 计算
Decode 访存密集(每次处理 1 Token,读取全部 KV) 逐步追加少量 KV 内存带宽

Decode 阶段是 KV Cache 的主要瓶颈:每个 Step 都需要从 HBM 读取完整 KV Cache,内存带宽成为限制因素。

QUESTION 面试题:如何减少 KV Cache 的显存占用? 多种方法可组合使用:

  1. GQA/MQA:减少 KV 头数,从结构上压缩(如 LLaMA-3 用 8 KV heads 替代 32)
  2. KV Cache 量化:FP16→FP8 可减半显存,精度损失极小
  3. PagedAttention:消除碎片,利用率从 20-40% 提升到 >95%
  4. 滑动窗口:固定 KV Cache 为 O(W)O(W),但会丢失长距离信息
  5. Token 淘汰:H2O 等策略移除不重要的 KV 条目
  6. 跨层共享:相邻层共享 KV,减少总存储量
  7. 降低序列长度:通过 Prompt 压缩、RAG 替代长上下文等

8. KV Cache 与推理加速的关系

KV Cache 是推理加速的核心枢纽,几乎所有推理优化技术都直接或间接作用于 KV Cache:

推理加速技术图谱(KV Cache 为中心)

                    ┌─ GQA / MQA ──── 减少 KV 头数
                    ├─ KV Cache 量化 ── 降低单条目大小
    KV Cache ───────├─ PagedAttention ─ 消除碎片,高效管理
    显存/带宽优化    ├─ 滑动窗口 ────── 固定 KV 大小
                    ├─ Token 淘汰 ──── 智能压缩
                    └─ Prefix Caching ─ 共享公共前缀

                    ┌─ FlashAttention ── 高效注意力计算内核
    计算加速 ───────├─ 投机解码 ──────── 减少 Decode 步数
                    └─ 算子融合 ──────── 减少内存读写次数

                    ┌─ Continuous Batching ─ 最大化 GPU 利用
    系统优化 ───────├─ Chunked Prefill ──── 避免长 Prefill 阻塞
                    └─ Tensor Parallelism ── 多 GPU 并行

QUESTION 面试题:推理速度优化主要有哪些方向? 主要分为三大方向:

  1. KV Cache 优化(显存/带宽):GQA、量化、PagedAttention、滑动窗口、Token 淘汰、Prefix Caching
  2. 计算优化(FLOPs):FlashAttention(分块计算减少 HBM 访问)、投机解码(并行验证多 Token)、算子融合
  3. 系统级优化(吞吐):Continuous Batching、Chunked Prefill、Tensor/Pipeline Parallelism

其中 Decode 阶段主要受限于显存带宽,因此 KV Cache 相关优化(方向 1)对 Decode 加速最为关键。