人工智能实践(语言智能)
第2讲:Transformer

2.3 Multi-Head 与 Transformer Block

为什么多头优于单头、拼接与投影的矩阵形式、参数量估算,以及残差 + LayerNorm + FFN 构成完整 Block

一个头不够:Single-Head 的"表示瓶颈"

上一节我们看到 Scaled Dot-Product Attention 已经可以让任意两个 token 互相建立联系,但一组 (WQ,WK,WV)(W^Q, W^K, W^V) 只能学到一种相关性模式。真实语言中的关系往往是多维的:

  • 句法关系:主语-谓语、修饰-被修饰
  • 语义关系:同义、反义、上下位
  • 共指关系:代词指向的先行词
  • 远距关系:段落首尾的呼应

强迫一组投影矩阵同时捕捉所有这些关系,会让注意力分布陷入"平均脸"——既不偏向句法、也不偏向语义,什么都沾一点又都不精。

Multi-Head Attention:把表示空间切成 hh

**Multi-Head Attention(MHA)**的思想很直接:让模型在 hh 个不同子空间里并行做 Attention,然后把结果拼起来:

headi=Attention(XWiQ,XWiK,XWiV)\text{head}_i = \text{Attention}(XW_i^Q, XW_i^K, XW_i^V) MultiHead(X)=Concat(head1,,headh)WO\text{MultiHead}(X) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)\, W^O

关键约定是每个头的维度 dk=dmodel/hd_k = d_{\text{model}} / h——总的计算量与单头同规模 Attention 几乎相同,但得到了 hh 个独立的"视角"。

  • 典型配置:dmodel=768d_{\text{model}} = 768(BERT-base)或 10241024(BERT-large),头数 h=12h = 121616
  • 每个头的维度 dk=64d_k = 64
  • 输出拼接后用 WORdmodel×dmodelW^O \in \mathbb{R}^{d_{\text{model}} \times d_{\text{model}}} 做一次融合投影

矩阵实现:一次大矩阵乘 + reshape

实际工程里不会真的跑 hh 次独立 Attention,而是hh 个头塞在同一个矩阵里一次算完

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.h = num_heads
        self.d_k = d_model // num_heads

        # 将 Q/K/V 的 h 个头的投影合并成一个大矩阵
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, mask=None):
        B, N, _ = x.shape
        # (B, N, d_model) -> (B, h, N, d_k)
        def split(t): return t.view(B, N, self.h, self.d_k).transpose(1, 2)
        Q, K, V = split(self.W_q(x)), split(self.W_k(x)), split(self.W_v(x))

        scores = Q @ K.transpose(-2, -1) / (self.d_k ** 0.5)  # (B, h, N, N)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        attn = F.softmax(scores, dim=-1)
        out = attn @ V                                         # (B, h, N, d_k)

        # 合并所有头
        out = out.transpose(1, 2).contiguous().view(B, N, self.d_model)
        return self.W_o(out), attn

注意 W_q / W_k / W_v 都是 dmodeldmodeld_{\text{model}} \to d_{\text{model}} 的线性层,而不是 hhdmodeldkd_{\text{model}} \to d_k 的小层。通过 view + transpose 重塑出 hh 个头——这是一行 reshape 的事,但在速度上比循环快一个数量级。

参数量估算

假设 dmodel=768d_{\text{model}} = 768h=12h = 12dk=64d_k = 64

组件参数量
WQW^Q768×768=589,824768 \times 768 = 589{,}824
WKW^K768×768=589,824768 \times 768 = 589{,}824
WVW^V768×768=589,824768 \times 768 = 589{,}824
WOW^O768×768=589,824768 \times 768 = 589{,}824
MHA 小计4×76822.36M4 \times 768^2 \approx 2.36\text{M}

Multi-Head 本身不增加参数——它只是把同样的 4dmodel24 d_{\text{model}}^2 参数hh 个视角上分配。真正"吃参数"的是下面要讲的 FFN。

残差 + LayerNorm:让 Transformer 能堆到几十层

Self-Attention 不是 Transformer Block 的全部。要让模型能稳定堆叠几十甚至上百层,必须加两个关键组件:

残差连接(Residual Connection)

y=x+Sublayer(x)y = x + \text{Sublayer}(x)

即"输入直接加回输出"。这条捷径让梯度可以绕过子层直接回传,解决深层网络的梯度消失,是 ResNet 带给整个深度学习的遗产。

Layer Normalization(LN)

LN(x)=γxμσ2+ϵ+β\text{LN}(x) = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta

其中 μ,σ\mu, \sigma 是沿隐维度(而非 batch 维度)计算的均值和标准差。LN 把每个 token 的激活归一化到均值 0、方差 1,再用可学习的 γ,β\gamma, \beta 恢复表达能力。它保证了不同层的输入分布稳定,是 Transformer 可训练的关键。

Pre-LN vs. Post-LN——原论文用的是 Post-LN(LN 放在残差之后:LN(x + Sublayer(x))),但后续实践(包括 GPT-2 / GPT-3 / LLaMA)全部改为 Pre-LN(LN 放在子层之前:x + Sublayer(LN(x)))。Pre-LN 训练更稳定、不需要 warmup 技巧,是今天的默认选择。

Feed-Forward Network:位置独立的非线性

每个 Transformer Block 在 Self-Attention 之后还有一个 Feed-Forward Network(FFN)

FFN(x)=W2σ(W1x+b1)+b2\text{FFN}(x) = W_2 \cdot \sigma(W_1 x + b_1) + b_2

其中 W1Rdmodel×dffW_1 \in \mathbb{R}^{d_{\text{model}} \times d_{\text{ff}}}, W2Rdff×dmodelW_2 \in \mathbb{R}^{d_{\text{ff}} \times d_{\text{model}}},通常 dff=4dmodeld_{\text{ff}} = 4 d_{\text{model}}σ\sigma 是激活函数(原论文用 ReLU,现代架构多用 GELU / SwiGLU)。

FFN 的参数量:2dmodeldff=8dmodel22 \cdot d_{\text{model}} \cdot d_{\text{ff}} = 8 d_{\text{model}}^2——是 MHA 的两倍。这也是为什么在大模型参数量估算里,FFN 通常是最大头。

FFN 虽然叫"前馈",但它是 position-wise 的:对每个位置独立做非线性变换,位置之间不交互。"位置交互"完全由 MHA 负责,FFN 只管"把每个 token 的表示往更抽象的方向推一步"。

完整的 Transformer Block

把上面的组件拼起来,就是 Transformer 的"标准 Block"(以 Pre-LN 为例):

对应的 PyTorch 代码:

class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
        )
        self.drop = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Pre-LN 残差块
        attn_out, _ = self.attn(self.ln1(x), mask)
        x = x + self.drop(attn_out)
        x = x + self.drop(self.ffn(self.ln2(x)))
        return x

单 Block 参数量表(dmodel=768,h=12,dff=3072d_{\text{model}}=768, h=12, d_{\text{ff}}=3072

组件参数量占比
Multi-Head Attention(含 WOW^O4dmodel2=2.36M4 d_{\text{model}}^2 = 2.36\text{M}33%
FFN(W1W_1 + W2W_28dmodel2=4.72M8 d_{\text{model}}^2 = 4.72\text{M}66%
LayerNorm × 24dmodel=3,0724 d_{\text{model}} = 3{,}072<1%
单 Block 合计7.1M\approx 7.1\text{M}

BERT-base 堆 12 个 Block,加上 Embedding(23M\approx 23\text{M})和输出头,总参数约 110M。这个估算方法对任何 Transformer 都适用——掌握了它,看到一个新模型(比如 Qwen3-7B)你就能立刻拆出它各部分占多少参数。

本节小结

问题答案
为什么要多头在不同子空间捕捉不同类型的关系,避免"平均脸"
多头如何实现投影矩阵一次算完,用 reshape 切分成 hh
多头会增加参数吗不会——把参数在 hh 个视角上分配
Block 的完整构造Pre-LN + MHA + 残差 + Pre-LN + FFN + 残差
参数大头在哪FFN(8dmodel28 d_{\text{model}}^2),是 MHA 的 2 倍

下一节进入另一个"非写不可"的细节——位置编码。为什么 Self-Attention 会丢失位置信息?Sinusoidal、RoPE、ALiBi 这些方案各有什么取舍?