人工智能实践(语言智能)
第2讲:Transformer

2.2 Self-Attention

Scaled Dot-Product Attention 的公式推导、矩阵视角、√dₖ 缩放的必要性,以及一个逐步展开的数值例子

从 Query-Key-Value 看 Attention

Self-Attention 的核心是把每个 token 同时扮演三个角色:

  • Query(Q)——"我想找什么样的信息?"
  • Key(K)——"我身上有什么样的信息?"
  • Value(V)——"如果你找到我,我实际能贡献什么?"

这和数据库检索(或者字典查找)的类比非常贴切:Query 是查询词,Key 是索引,Value 是真实内容;只不过在 Attention 中,匹配不是硬匹配,而是连续的、可导的"软匹配"。

对一个输入序列 XRn×dmodelX \in \mathbb{R}^{n \times d_{\text{model}}}nn 为序列长度,dmodeld_{\text{model}} 为隐维度),用三组投影矩阵得到:

Q=XWQ,K=XWK,V=XWVQ = X W^Q, \quad K = X W^K, \quad V = X W^V

其中 WQ,WKRdmodel×dkW^Q, W^K \in \mathbb{R}^{d_{\text{model}} \times d_k}WVRdmodel×dvW^V \in \mathbb{R}^{d_{\text{model}} \times d_v}。通常取 dk=dv=dmodel/hd_k = d_v = d_{\text{model}} / hhh 为头数,见下节)。

Scaled Dot-Product Attention 公式

Vaswani 等人给出的核心公式是:

Attention(Q,K,V)=softmax(QKdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) V

这个公式可以拆成四步:

计算相似度分数

QKRn×nQK^\top \in \mathbb{R}^{n \times n}。第 (i,j)(i, j) 个元素是 Query qiq_i 与 Key kjk_j 的点积,反映 token ii 对 token jj 的关注度(未归一化)。

缩放(Scale)

除以 dk\sqrt{d_k}。为什么要除这个数?稍后专门讲。

Softmax 归一化

对每一行做 softmax,得到注意力权重 ARn×nA \in \mathbb{R}^{n \times n},每行元素非负且和为 1。AijA_{ij} 表示"token ii 有多大比例在看 token jj"。

加权求和 Value

AVRn×dvAV \in \mathbb{R}^{n \times d_v}。第 ii 行是按注意力权重加权的 Value 向量——这就是 token ii 经过一次 Self-Attention 后的新表示。

为什么要除以 dk\sqrt{d_k}

这个缩放因子是 Transformer 原论文的重要细节,也是考试和面试中的高频考点。

直觉推导:假设 QQKK 的每一维都是均值为 0、方差为 1 的独立随机变量,那么它们的点积

qk=i=1dkqikiq \cdot k = \sum_{i=1}^{d_k} q_i k_i

的方差是 dkd_kdkd_k 个方差为 1 的独立项之和)。当 dk=64d_k = 64 时,点积的标准差约为 8;当 dk=128d_k = 128 时,标准差约为 11.3。

问题:softmax 对输入的数值量级极为敏感。如果输入的量级过大,softmax 的输出就会集中在一个峰值上,接近 one-hot 分布——梯度几乎全为零,无法训练

softmax(z)i=ezijezj\text{softmax}(z)_i = \frac{e^{z_i}}{\sum_j e^{z_j}}

max(z)second(z)1\max(z) - \text{second}(z) \gg 1 时,梯度对大多数输入都几乎为零。

解决:除以 dk\sqrt{d_k} 把点积的方差拉回 1,让 softmax 的输入停留在梯度良好的区间。

这就是 "Scaled" Dot-Product 里 "Scaled" 的由来。在后续位置编码(如 RoPE)或 FlashAttention 的实现中,这个缩放因子依然保留。如果你在实现时忘掉 dk\sqrt{d_k},训练通常能起来但收敛非常慢,且容易发散。

矩阵视角:三步矩阵乘法

把所有公式连起来,Self-Attention 本质上就是 3 个矩阵乘法 + 1 个 softmax

这张图非常关键——所有 token 在一次矩阵乘法里同时完成 Attention。对比 RNN 一步一步算隐状态,差距一目了然。

一个数值小例子:3 个 token,dk=2d_k = 2

为了把公式落地,我们算一个最小的例子。设 n=3n = 3, dk=dv=2d_k = d_v = 2,并给出 Q,K,VQ, K, V(假设投影已完成):

Q=(100111),K=(100110),V=(10001055)Q = \begin{pmatrix} 1 & 0 \\ 0 & 1 \\ 1 & 1 \end{pmatrix},\quad K = \begin{pmatrix} 1 & 0 \\ 0 & 1 \\ 1 & 0 \end{pmatrix},\quad V = \begin{pmatrix} 10 & 0 \\ 0 & 10 \\ 5 & 5 \end{pmatrix}

第一步:QKQK^\top

QK=(101010111)QK^\top = \begin{pmatrix} 1 & 0 & 1 \\ 0 & 1 & 0 \\ 1 & 1 & 1 \end{pmatrix}

读一下第 3 行:q3=(1,1)q_3 = (1, 1)k1,k2,k3k_1, k_2, k_3 分别点积得 1, 1, 1——token 3 对三个 token 的原始相似度都相同。

第二步:缩放 dk=21.414\sqrt{d_k} = \sqrt{2} \approx 1.414

QK2(0.70700.70700.70700.7070.7070.707)\frac{QK^\top}{\sqrt{2}} \approx \begin{pmatrix} 0.707 & 0 & 0.707 \\ 0 & 0.707 & 0 \\ 0.707 & 0.707 & 0.707 \end{pmatrix}

第三步:softmax 按行归一化

对第 1 行 (0.707,0,0.707)(0.707, 0, 0.707):指数为 (2.028,1,2.028)(2.028, 1, 2.028),和为 5.0565.056,归一化后得 (0.401,0.198,0.401)(0.401, 0.198, 0.401)

第 2 行 (0,0.707,0)(0, 0.707, 0):指数为 (1,2.028,1)(1, 2.028, 1),和为 4.0284.028,归一化得 (0.248,0.504,0.248)(0.248, 0.504, 0.248)

第 3 行 (0.707,0.707,0.707)(0.707, 0.707, 0.707):三项相等,归一化得 (1/3,1/3,1/3)(1/3, 1/3, 1/3)

A(0.4010.1980.4010.2480.5040.2480.3330.3330.333)A \approx \begin{pmatrix} 0.401 & 0.198 & 0.401 \\ 0.248 & 0.504 & 0.248 \\ 0.333 & 0.333 & 0.333 \end{pmatrix}

第四步:AVAV

逐行计算:

  • 第 1 行:0.401(10,0)+0.198(0,10)+0.401(5,5)=(6.02,3.99)0.401 \cdot (10, 0) + 0.198 \cdot (0, 10) + 0.401 \cdot (5, 5) = (6.02, 3.99)
  • 第 2 行:0.248(10,0)+0.504(0,10)+0.248(5,5)=(3.72,6.28)0.248 \cdot (10, 0) + 0.504 \cdot (0, 10) + 0.248 \cdot (5, 5) = (3.72, 6.28)
  • 第 3 行:13(10+0+5,0+10+5)=(5.0,5.0)\frac{1}{3}(10 + 0 + 5, 0 + 10 + 5) = (5.0, 5.0)
Attention(Q,K,V)(6.023.993.726.285.005.00)\text{Attention}(Q, K, V) \approx \begin{pmatrix} 6.02 & 3.99 \\ 3.72 & 6.28 \\ 5.00 & 5.00 \end{pmatrix}

读图解义:token 3 的 Query 对所有 Key 同样相似,最终输出是三个 Value 的均值 (5,5)(5, 5);token 1 的 Query 和 Key 1、Key 3 最匹配,输出就偏向 V1V_1V3V_3

极简 PyTorch 实现

把公式直接翻译成 PyTorch(不含 Multi-Head 和 Mask,后续章节再加):

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Q: (batch, n, d_k)
    K: (batch, n, d_k)
    V: (batch, n, d_v)
    """
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)  # (b, n, n)

    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))

    attn = F.softmax(scores, dim=-1)
    out = torch.matmul(attn, V)  # (b, n, d_v)
    return out, attn

三个实现陷阱:(1)不要忘记 dk\sqrt{d_k} 缩放;(2)mask 要在 softmax 之前把被屏蔽位置置为 -\infty(而不是 0);(3)transpose(-2, -1) 只换最后两维,保持 batch 维不动。

本节小结

要点内容
Attention 的直觉数据库检索:Q 是查询、K 是索引、V 是内容
核心公式Attention(Q,K,V)=softmax(QKdk)V\text{Attention}(Q,K,V)=\text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V
为什么除 dk\sqrt{d_k}把点积方差拉回 1,避免 softmax 梯度消失
计算复杂度O(n2d)O(n^2 \cdot d):空间在 n×nn \times n 的注意力矩阵
并行性所有 token 在一次矩阵乘法里同时完成

下一节讨论为什么一个头不够——引入 Multi-Head Attention 和 Transformer Block 的完整构造。