2025年5月12日

Attention 与 Self-Attention 机制

Attention 是一种让模型在处理序列时,能够动态地"聚焦"于输入中最相关部分的机制;Self-Attention 是 Attention 的特例,其中 Query、Key、Value 全部来自同一个序列,使序列中的每个位置都能与所有其...

知识库大模型基础原理llmmodeltransformerattention

先说结论

Attention 是一种让模型在处理序列时,能够动态地"聚焦"于输入中最相关部分的机制;Self-Attention 是 Attention 的特例,其中 Query、Key、Value 全部来自同一个序列,使序列中的每个位置都能与所有其他位置交互。

为什么我会单独记这一篇

在 Transformer 出现之前,序列建模主要依赖 RNN/LSTM 和 CNN。RNN 存在三个根本性缺陷:

  1. 长距离依赖衰减:信息必须沿时间步依次传递,梯度在反向传播中不断衰减,导致长序列中远端 token 之间的关联难以学习。
  2. 串行计算瓶颈:每个时间步依赖前一步的输出,无法并行化,训练效率低下。
  3. 固定大小的上下文压缩:无论输入多长,RNN 都要将其压缩到一个隐状态向量中,信息瓶颈严重。

Attention 机制直接解决了这三个问题:任意两个位置之间的信息传递路径长度为 O(1),计算可以完全并行化,每个位置都可以直接访问所有其他位置的完整信息,不存在压缩瓶颈。

先把核心脉络捋清楚

Scaled Dot-Product Attention(缩放点积注意力)

公式:

Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) * V

各要素的含义:

  • Query (Q):当前位置发出的"查询"向量,表示"我在找什么信息"。
  • Key (K):每个位置的"键"向量,表示"我能提供什么信息"。
  • Value (V):每个位置的"值"向量,表示"我实际携带的信息内容"。
  • QK^T:计算 Query 与所有 Key 的相似度(点积),得到注意力分数矩阵。
  • 除以 sqrt(d_k):缩放因子,防止点积值过大导致 softmax 进入梯度极小的饱和区域。这是"缩放"点积注意力与普通点积注意力的关键区别。
  • softmax:将分数归一化为概率分布,每个位置对其他位置的注意力权重之和为 1。
  • 乘以 V:用注意力权重对 Value 加权求和,得到每个位置的最终输出。

直觉上,这就像在图书馆查资料:你带着一个问题(Query),浏览所有书架标签(Key),找到最相关的几本书,然后从中提取信息(Value)。

Multi-Head Attention(多头注意力)

公式:

MultiHead(Q, K, V) = Concat(head_1, ..., head_h) * W_O

其中 head_i = Attention(Q * W_i^Q, K * W_i^K, V * W_i^V)

核心思想:不使用单一的注意力函数,而是将 Q、K、V 分别投影到 h 个不同的低维子空间中,在每个子空间独立计算注意力,然后将结果拼接并线性投影回原始维度。

原始 Transformer 使用 h=8 个头,每个头的维度 d_k = d_v = 64,总维度 d_model = 512。

为什么要多头?单个注意力头倾向于关注一种特定类型的关系(如语法关系、语义关联、位置相近等)。多头允许模型同时捕捉多种不同类型的依赖关系,每个头学习不同的"注意力模式"。

Self-Attention(自注意力)

当 Q、K、V 全部来自同一个序列时,就是 Self-Attention。这是 Transformer Encoder 的核心操作。每个 token 都可以直接"看到"序列中的所有其他 token,无论距离远近。

Cross-Attention(交叉注意力)

Q 来自一个序列(如解码器),K 和 V 来自另一个序列(如编码器输出)。这是 Transformer Decoder 中连接编码器和解码器的桥梁。

Masked Self-Attention(掩码自注意力)

在自回归生成中,当前位置不能看到未来位置的信息。通过在 softmax 之前将未来位置的注意力分数设为负无穷(即掩码),确保模型只能关注当前及之前的 token。

原理拆开看

缩放的必要性

假设 Q 和 K 的每个分量独立且均值为 0、方差为 1。那么点积 Q·K 的方差为 d_k。当 d_k 较大时(如 64 或 128),点积值的量级会很大,导致 softmax 输出接近 one-hot 分布,梯度趋近于零。除以 sqrt(d_k) 将方差归一化为 1,使 softmax 保持合理的梯度。

这不是一个微不足道的细节。如果忘记缩放,在 d_k 较大时模型训练会变得极其困难,损失函数可能不收敛。

计算复杂度

标准 Self-Attention 的时间和空间复杂度为 O(n^2 * d),其中 n 是序列长度,d 是模型维度。这是因为需要计算 n x n 的注意力矩阵。这是 Transformer 处理超长序列时的主要瓶颈。

对于序列长度 n 远大于维度 d 的场景(如长文档),注意力计算成为瓶颈。这也是 Flash Attention、Sparse Attention、Sliding Window Attention 等优化技术出现的动机。

注意力的几何直觉

可以将注意力理解为在高维空间中对信息进行软性路由。softmax 产生的权重让每个位置从其他位置"选择性地聚合"信息。与 CNN 的局部感受野不同,注意力的感受野是全局的;与 RNN 的顺序传播不同,注意力的信息传递是直接的。

放到工程里怎么落

标准实现步骤

  1. 输入 token 经过嵌入层得到 X (shape: [batch, seq_len, d_model])。
  2. 通过三个线性层分别计算 Q = X @ W_Q, K = X @ W_K, V = X @ W_V。
  3. 计算注意力分数 scores = Q @ K^T / sqrt(d_k)。
  4. (可选)应用注意力掩码(mask)。
  5. 应用 softmax 得到注意力权重。
  6. (可选)应用 dropout。
  7. 加权求和 output = weights @ V。
  8. 多头情况下,拼接所有头的输出并通过 W_O 线性投影。

Flash Attention

标准实现需要将完整的 n x n 注意力矩阵存储在 GPU 高带宽内存(HBM)中,内存开销 O(n^2)。Flash Attention 通过分块计算(tiling)和在线 softmax(online softmax)技术,将内存复杂度降低到 O(n),同时计算精确结果(非近似)。Flash Attention-2(2023)和 Flash Attention-3(2024)进一步优化了并行度和硬件利用率。

Multi-Query Attention (MQA) 与 Grouped-Query Attention (GQA)

推理时的 KV Cache 是内存瓶颈。MQA 让所有 Query 头共享一组 Key/Value 头(即 K、V 只有 1 个头),大幅减少 KV Cache 大小。GQA 是 MHA 和 MQA 的折中,使用 G 组 KV 头(1 < G < h)。LLaMA 2/3、Mistral 等现代模型普遍采用 GQA。

与相邻概念的区别

  • Attention vs Self-Attention:Attention 是通用概念,Q/K/V 可以来自不同来源;Self-Attention 是 Q/K/V 来自同一来源的特例。
  • Self-Attention vs Cross-Attention:Cross-Attention 的 Q 和 K/V 来自不同序列,用于连接编码器和解码器。
  • Attention vs Convolution:卷积是局部、固定的权重;注意力是全局、动态的权重。
  • Attention vs Pooling:池化是静态聚合;注意力是加权聚合,权重由数据决定。

设计时真正要权衡什么

设计选择 优势 代价
点积注意力 vs 加性注意力 计算更快,可利用矩阵乘法优化 在极小维度上可能不如加性注意力
多头 vs 单头 捕捉多种关系模式 增加参数量和计算量
全局注意力 vs 稀疏注意力 完整的长距离依赖 O(n^2) 复杂度
MHA vs GQA vs MQA 推理效率递增 模型容量/质量可能递减

容易踩的坑

  1. 忘记缩放因子 sqrt(d_k):模型训练不收敛,或需要极小的学习率才能勉强工作。
  2. 掩码实现错误:自回归生成中出现信息泄露,训练时 loss 正常但推理时质量极差。
  3. 注意力矩阵数值溢出/下溢:在 fp16 混合精度训练中,softmax 之前的分数可能超出 fp16 表示范围。解决方案是使用 log-sum-exp 技巧或 Flash Attention。
  4. KV Cache 管理不当:推理时内存爆炸或 token 重复/遗漏。
  5. 多头退化为单头:如果初始化或正则化不当,多个头可能学习到相同模式,浪费模型容量。

工程落地时我会怎么做

  1. 始终使用 Flash Attention 或其等价实现,除非有特殊原因不能使用。
  2. 训练时在注意力权重上使用 dropout(原始 Transformer 用 0.1),推理时关闭。
  3. 使用 GQA(如 LLaMA 的做法)在推理效率和模型质量之间取得平衡。
  4. 注意力权重的可视化是调试模型行为的有效工具。
  5. 在混合精度训练中,注意 softmax 的数值稳定性(使用 fp32 softmax 或 Flash Attention)。

如果要对外讲,可以怎么概括

建议结构:

  1. 先用一句话解释注意力:"注意力是一种让序列中每个位置动态地从所有其他位置聚合信息的机制。"
  2. 写出缩放点积注意力的公式,解释每一项的含义。
  3. 解释为什么需要缩放(sqrt(d_k)),这是面试高频考点。
  4. 解释多头注意力的动机和实现。
  5. 讨论复杂度 O(n^2) 的瓶颈及 Flash Attention 的解决方案。
  6. 如果时间允许,延伸到 MQA/GQA 的推理优化。

常见面试追问:

  • 为什么不直接用点积而需要缩放?(防止 softmax 饱和)
  • 注意力机制与 RNN/CNN 各自的优劣?(并行性、长距离依赖、计算复杂度)
  • 如何处理超长序列的注意力?(Flash Attention、稀疏注意力、滑动窗口)
  • MQA 和 GQA 的区别和动机?(KV Cache 大小 vs 模型质量)

最后记几条

  1. 公式:Attention(Q,K,V) = softmax(QK^T / sqrt(d_k)) V —— 这是一切的起点。
  2. 缩放因子:sqrt(d_k) 不是可选的,去掉会导致训练不稳定。
  3. 多头的本质:让模型同时学习多种不同的注意力模式,而非只关注一种关系。
  4. 复杂度瓶颈:O(n^2) 的内存和时间复杂度是 Transformer 的阿喀琉斯之踵。
  5. Flash Attention:通过分块计算和在线 softmax 将内存复杂度降为 O(n),同时保持精确计算。

延伸阅读

面试高频题

QUESTION 为什么要点积注意力除以 dk\sqrt{d_k}? 假设 Q 和 K 的每个分量独立且均值为 0、方差为 1,则点积 qk=i=1dkqikiq \cdot k = \sum_{i=1}^{d_k} q_i k_i 的方差为 dkd_k。当 dkd_k 较大时(如 64),点积值量级很大,导致 softmax 输出接近 one-hot 分布,梯度趋近于零。除以 dk\sqrt{d_k} 将方差归一化为 1,使 softmax 保持合理的梯度。如果去掉缩放,大模型训练会极不稳定。

QUESTION Self-Attention、Cross-Attention、Masked Self-Attention 有什么区别?

类型 Q 来源 K/V 来源 用途
Self-Attention 同一序列 同一序列 编码器中,全局上下文交互
Masked Self-Attention 同一序列 同一序列(加掩码) 解码器中,防止看到未来
Cross-Attention 解码器 编码器输出 解码器访问输入信息

核心区别在于 Q/K/V 的来源和是否使用因果掩码。

QUESTION 多头注意力的"多头"有什么意义? 单个注意力头倾向于学习一种特定类型的依赖关系(如语法关系、语义关联、位置接近等)。多头允许模型同时学习 hh 种不同的注意力模式,每个头在独立的低维子空间中操作。直觉上类似于 CNN 中多个 filter 捕捉不同特征。原始 Transformer 用 h=8h=8,现代模型可达 32-128 头。

QUESTION Attention 有哪些变体?

变体 核心思想 复杂度
标准点积注意力 softmax(QKT/dk)Vsoftmax(QK^T/\sqrt{d_k})V O(n2d)O(n^2 d)
Sparse Attention 只计算局部+全局注意力 O(nn)O(n \sqrt{n})
Linear Attention 核函数近似替代 softmax O(nd2)O(nd^2)
Multi-Head Attention 多头并行,捕捉多种模式 O(n2d)O(n^2 d)
Multi-Query Attention 所有 Q 头共享 1 组 KV O(n2d)O(n^2 d),推理更快
Grouped-Query Attention Q 头分组,每组共享 KV O(n2d)O(n^2 d),推理/质量平衡

QUESTION Self-Attention 的计算量和参数量分别是多少?

  • 计算量4nd2+2n2d4nd^2 + 2n^2d FLOPs(其中 4nd24nd^2 来自 Q/K/V/O 四个线性投影,2n2d2n^2d 来自注意力矩阵计算和加权求和)
  • 参数量4d24d^2(四个权重矩阵 WQ,WK,WV,WOW_Q, W_K, W_V, W_O,均不考虑偏置)
  • ndn \gg d 时,注意力矩阵计算 2n2d2n^2d 占主导;当 dnd \gg n 时,线性投影 4nd24nd^2 占主导

QUESTION Attention 中的 dropout 加在哪里? 原始 Transformer 中 dropout 加在两个位置:(1) 注意力权重 softmax 之后(对注意力概率矩阵 dropout);(2) 每个子层输出 Add & Norm 之前。训练时 dropout=0.1,推理时关闭。

注意力机制类型对比

维度 点积注意力 加性注意力(Bahdanau)
计算方式 QKTQK^T(矩阵乘法) vTtanh(WQ+UK)v^T \tanh(WQ + UK)
计算速度 快,可利用 GPU 矩阵乘优化 慢,逐元素计算
理论表现 大维度下更优 极小维度可能更好
实际使用 Transformer 系列全部采用 早期 seq2seq 模型

参考资料

  • Vaswani, A. et al. (2017). "Attention Is All You Need." NeurIPS 2017. arXiv:1706.03762
  • Dao, T. et al. (2022). "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness." NeurIPS 2022.
  • Shazeer, N. (2019). "Fast Transformer Decoding: One Write-Head is All You Need." arXiv:1911.02150
  • Ainslie, J. et al. (2023). "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints." EMNLP 2023.