2.3 Multi-Head 与 Transformer Block
为什么多头优于单头、拼接与投影的矩阵形式、参数量估算,以及残差 + LayerNorm + FFN 构成完整 Block
一个头不够:Single-Head 的"表示瓶颈"
上一节我们看到 Scaled Dot-Product Attention 已经可以让任意两个 token 互相建立联系,但一组 只能学到一种相关性模式。真实语言中的关系往往是多维的:
- 句法关系:主语-谓语、修饰-被修饰
- 语义关系:同义、反义、上下位
- 共指关系:代词指向的先行词
- 远距关系:段落首尾的呼应
强迫一组投影矩阵同时捕捉所有这些关系,会让注意力分布陷入"平均脸"——既不偏向句法、也不偏向语义,什么都沾一点又都不精。
Multi-Head Attention:把表示空间切成 份
**Multi-Head Attention(MHA)**的思想很直接:让模型在 个不同子空间里并行做 Attention,然后把结果拼起来:
关键约定是每个头的维度 ——总的计算量与单头同规模 Attention 几乎相同,但得到了 个独立的"视角"。
- 典型配置:(BERT-base)或 (BERT-large),头数 或
- 每个头的维度
- 输出拼接后用 做一次融合投影
矩阵实现:一次大矩阵乘 + reshape
实际工程里不会真的跑 次独立 Attention,而是把 个头塞在同一个矩阵里一次算完:
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
def __init__(self, d_model: int, num_heads: int):
super().__init__()
assert d_model % num_heads == 0
self.d_model = d_model
self.h = num_heads
self.d_k = d_model // num_heads
# 将 Q/K/V 的 h 个头的投影合并成一个大矩阵
self.W_q = nn.Linear(d_model, d_model, bias=False)
self.W_k = nn.Linear(d_model, d_model, bias=False)
self.W_v = nn.Linear(d_model, d_model, bias=False)
self.W_o = nn.Linear(d_model, d_model, bias=False)
def forward(self, x, mask=None):
B, N, _ = x.shape
# (B, N, d_model) -> (B, h, N, d_k)
def split(t): return t.view(B, N, self.h, self.d_k).transpose(1, 2)
Q, K, V = split(self.W_q(x)), split(self.W_k(x)), split(self.W_v(x))
scores = Q @ K.transpose(-2, -1) / (self.d_k ** 0.5) # (B, h, N, N)
if mask is not None:
scores = scores.masked_fill(mask == 0, float("-inf"))
attn = F.softmax(scores, dim=-1)
out = attn @ V # (B, h, N, d_k)
# 合并所有头
out = out.transpose(1, 2).contiguous().view(B, N, self.d_model)
return self.W_o(out), attn注意 W_q / W_k / W_v 都是 的线性层,而不是 个 的小层。通过 view + transpose 重塑出 个头——这是一行 reshape 的事,但在速度上比循环快一个数量级。
参数量估算
假设 ,,:
| 组件 | 参数量 |
|---|---|
| MHA 小计 |
Multi-Head 本身不增加参数——它只是把同样的 参数在 个视角上分配。真正"吃参数"的是下面要讲的 FFN。
残差 + LayerNorm:让 Transformer 能堆到几十层
Self-Attention 不是 Transformer Block 的全部。要让模型能稳定堆叠几十甚至上百层,必须加两个关键组件:
残差连接(Residual Connection)
即"输入直接加回输出"。这条捷径让梯度可以绕过子层直接回传,解决深层网络的梯度消失,是 ResNet 带给整个深度学习的遗产。
Layer Normalization(LN)
其中 是沿隐维度(而非 batch 维度)计算的均值和标准差。LN 把每个 token 的激活归一化到均值 0、方差 1,再用可学习的 恢复表达能力。它保证了不同层的输入分布稳定,是 Transformer 可训练的关键。
Pre-LN vs. Post-LN——原论文用的是 Post-LN(LN 放在残差之后:LN(x + Sublayer(x))),但后续实践(包括 GPT-2 / GPT-3 / LLaMA)全部改为 Pre-LN(LN 放在子层之前:x + Sublayer(LN(x)))。Pre-LN 训练更稳定、不需要 warmup 技巧,是今天的默认选择。
Feed-Forward Network:位置独立的非线性
每个 Transformer Block 在 Self-Attention 之后还有一个 Feed-Forward Network(FFN):
其中 , ,通常 。 是激活函数(原论文用 ReLU,现代架构多用 GELU / SwiGLU)。
FFN 的参数量:——是 MHA 的两倍。这也是为什么在大模型参数量估算里,FFN 通常是最大头。
FFN 虽然叫"前馈",但它是 position-wise 的:对每个位置独立做非线性变换,位置之间不交互。"位置交互"完全由 MHA 负责,FFN 只管"把每个 token 的表示往更抽象的方向推一步"。
完整的 Transformer Block
把上面的组件拼起来,就是 Transformer 的"标准 Block"(以 Pre-LN 为例):
对应的 PyTorch 代码:
class TransformerBlock(nn.Module):
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super().__init__()
self.ln1 = nn.LayerNorm(d_model)
self.attn = MultiHeadAttention(d_model, num_heads)
self.ln2 = nn.LayerNorm(d_model)
self.ffn = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.GELU(),
nn.Linear(d_ff, d_model),
)
self.drop = nn.Dropout(dropout)
def forward(self, x, mask=None):
# Pre-LN 残差块
attn_out, _ = self.attn(self.ln1(x), mask)
x = x + self.drop(attn_out)
x = x + self.drop(self.ffn(self.ln2(x)))
return x单 Block 参数量表()
| 组件 | 参数量 | 占比 |
|---|---|---|
| Multi-Head Attention(含 ) | 33% | |
| FFN( + ) | 66% | |
| LayerNorm × 2 | <1% | |
| 单 Block 合计 | — |
BERT-base 堆 12 个 Block,加上 Embedding()和输出头,总参数约 110M。这个估算方法对任何 Transformer 都适用——掌握了它,看到一个新模型(比如 Qwen3-7B)你就能立刻拆出它各部分占多少参数。
本节小结
| 问题 | 答案 |
|---|---|
| 为什么要多头 | 在不同子空间捕捉不同类型的关系,避免"平均脸" |
| 多头如何实现 | 投影矩阵一次算完,用 reshape 切分成 组 |
| 多头会增加参数吗 | 不会——把参数在 个视角上分配 |
| Block 的完整构造 | Pre-LN + MHA + 残差 + Pre-LN + FFN + 残差 |
| 参数大头在哪 | FFN(),是 MHA 的 2 倍 |
下一节进入另一个"非写不可"的细节——位置编码。为什么 Self-Attention 会丢失位置信息?Sinusoidal、RoPE、ALiBi 这些方案各有什么取舍?