Self-Attention 的核心是把每个 token 同时扮演三个角色:
- Query(Q)——"我想找什么样的信息?"
- Key(K)——"我身上有什么样的信息?"
- Value(V)——"如果你找到我,我实际能贡献什么?"
这和数据库检索(或者字典查找)的类比非常贴切:Query 是查询词,Key 是索引,Value 是真实内容;只不过在 Attention 中,匹配不是硬匹配,而是连续的、可导的"软匹配"。
对一个输入序列 X∈Rn×dmodel(n 为序列长度,dmodel 为隐维度),用三组投影矩阵得到:
Q=XWQ,K=XWK,V=XWV
其中 WQ,WK∈Rdmodel×dk,WV∈Rdmodel×dv。通常取 dk=dv=dmodel/h(h 为头数,见下节)。
Vaswani 等人给出的核心公式是:
Attention(Q,K,V)=softmax(dkQK⊤)V
这个公式可以拆成四步:
QK⊤∈Rn×n。第 (i,j) 个元素是 Query qi 与 Key kj 的点积,反映 token i 对 token j 的关注度(未归一化)。
除以 dk。为什么要除这个数?稍后专门讲。
对每一行做 softmax,得到注意力权重 A∈Rn×n,每行元素非负且和为 1。Aij 表示"token i 有多大比例在看 token j"。
AV∈Rn×dv。第 i 行是按注意力权重加权的 Value 向量——这就是 token i 经过一次 Self-Attention 后的新表示。
这个缩放因子是 Transformer 原论文的重要细节,也是考试和面试中的高频考点。
直觉推导:假设 Q 和 K 的每一维都是均值为 0、方差为 1 的独立随机变量,那么它们的点积
q⋅k=i=1∑dkqiki
的方差是 dk(dk 个方差为 1 的独立项之和)。当 dk=64 时,点积的标准差约为 8;当 dk=128 时,标准差约为 11.3。
问题:softmax 对输入的数值量级极为敏感。如果输入的量级过大,softmax 的输出就会集中在一个峰值上,接近 one-hot 分布——梯度几乎全为零,无法训练。
softmax(z)i=∑jezjezi
当 max(z)−second(z)≫1 时,梯度对大多数输入都几乎为零。
解决:除以 dk 把点积的方差拉回 1,让 softmax 的输入停留在梯度良好的区间。
这就是 "Scaled" Dot-Product 里 "Scaled" 的由来。在后续位置编码(如 RoPE)或 FlashAttention 的实现中,这个缩放因子依然保留。如果你在实现时忘掉 dk,训练通常能起来但收敛非常慢,且容易发散。
把所有公式连起来,Self-Attention 本质上就是 3 个矩阵乘法 + 1 个 softmax:
这张图非常关键——所有 token 在一次矩阵乘法里同时完成 Attention。对比 RNN 一步一步算隐状态,差距一目了然。
为了把公式落地,我们算一个最小的例子。设 n=3, dk=dv=2,并给出 Q,K,V(假设投影已完成):
Q=101011,K=101010,V=10050105
第一步:QK⊤
QK⊤=101011101
读一下第 3 行:q3=(1,1) 与 k1,k2,k3 分别点积得 1, 1, 1——token 3 对三个 token 的原始相似度都相同。
第二步:缩放 dk=2≈1.414
2QK⊤≈0.70700.70700.7070.7070.70700.707
第三步:softmax 按行归一化
对第 1 行 (0.707,0,0.707):指数为 (2.028,1,2.028),和为 5.056,归一化后得 (0.401,0.198,0.401)。
第 2 行 (0,0.707,0):指数为 (1,2.028,1),和为 4.028,归一化得 (0.248,0.504,0.248)。
第 3 行 (0.707,0.707,0.707):三项相等,归一化得 (1/3,1/3,1/3)。
A≈0.4010.2480.3330.1980.5040.3330.4010.2480.333
第四步:AV
逐行计算:
- 第 1 行:0.401⋅(10,0)+0.198⋅(0,10)+0.401⋅(5,5)=(6.02,3.99)
- 第 2 行:0.248⋅(10,0)+0.504⋅(0,10)+0.248⋅(5,5)=(3.72,6.28)
- 第 3 行:31(10+0+5,0+10+5)=(5.0,5.0)
Attention(Q,K,V)≈6.023.725.003.996.285.00
读图解义:token 3 的 Query 对所有 Key 同样相似,最终输出是三个 Value 的均值 (5,5);token 1 的 Query 和 Key 1、Key 3 最匹配,输出就偏向 V1 和 V3。
把公式直接翻译成 PyTorch(不含 Multi-Head 和 Mask,后续章节再加):
import torch
import torch.nn.functional as F
def scaled_dot_product_attention(Q, K, V, mask=None):
"""
Q: (batch, n, d_k)
K: (batch, n, d_k)
V: (batch, n, d_v)
"""
d_k = Q.size(-1)
scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5) # (b, n, n)
if mask is not None:
scores = scores.masked_fill(mask == 0, float("-inf"))
attn = F.softmax(scores, dim=-1)
out = torch.matmul(attn, V) # (b, n, d_v)
return out, attn
三个实现陷阱:(1)不要忘记 dk 缩放;(2)mask 要在 softmax 之前把被屏蔽位置置为 −∞(而不是 0);(3)transpose(-2, -1) 只换最后两维,保持 batch 维不动。
| 要点 | 内容 |
|---|
| Attention 的直觉 | 数据库检索:Q 是查询、K 是索引、V 是内容 |
| 核心公式 | Attention(Q,K,V)=softmax(dkQK⊤)V |
| 为什么除 dk | 把点积方差拉回 1,避免 softmax 梯度消失 |
| 计算复杂度 | O(n2⋅d):空间在 n×n 的注意力矩阵 |
| 并行性 | 所有 token 在一次矩阵乘法里同时完成 |
下一节讨论为什么一个头不够——引入 Multi-Head Attention 和 Transformer Block 的完整构造。