人工智能实践(语言智能)
第2讲:Transformer

实验 2:Transformer 最小实现与注意力可视化

两选一——用 100 行 PyTorch 实现 MiniGPT 并训练;或用 transformers 加载 BERT 并可视化多头注意力

实验概述

本讲的实验有两条路径,你可以根据兴趣和硬件条件选其中一条。两条都能把这一讲的抽象知识落到键盘上

选项标题硬件要求适合人群
A100 行 PyTorch 实现 MiniGPTGPU(RTX 3060 或以上即可);CPU 也能跑小规模想彻底理解内部机制、偏算法
BBERT 多头注意力可视化CPU 即可,任意电脑想用现成模型、偏应用与可解释性

推荐:时间充裕的同学两条都做——先做 B(熟悉 transformers API 与注意力图谱),再做 A(把公式和代码一一对应)。


选项 A:100 行 PyTorch 实现 MiniGPT

目标

从零实现一个极简的 Decoder-only Transformer(GPT 风格),在莎士比亚数据集(Shakespeare tiny,约 1MB)上训练一个字符级语言模型。训练完成后能生成仿莎士比亚风格的文本。

本实验致敬 Andrej Karpathy 的 nanoGPT 项目,但刻意把代码精简到 ~120 行,方便逐行对照公式。

步骤 1:环境与数据准备(10 分钟)

pip install torch tiktoken numpy matplotlib

下载莎士比亚数据集:

import urllib.request, os
url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
if not os.path.exists("input.txt"):
    urllib.request.urlretrieve(url, "input.txt")

text = open("input.txt", "r", encoding="utf-8").read()
print(f"数据长度:{len(text)} 字符")
print(f"前 200 字:{text[:200]}")

构建字符级词表:

import torch

chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: "".join([itos[i] for i in l])

data = torch.tensor(encode(text), dtype=torch.long)
n_train = int(0.9 * len(data))
train_data, val_data = data[:n_train], data[n_train:]
print(f"词表大小:{vocab_size}, 训练 token:{len(train_data)}")

步骤 2:实现 MiniGPT(30 分钟)

把第 2.2、2.3 两节的公式整合成一个 120 行以内的 GPT:

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# ===== 超参数 =====
BLOCK_SIZE = 128       # 上下文长度
N_LAYER = 4            # Transformer Block 数量
N_HEAD = 4             # 多头数
N_EMBD = 128           # d_model
DROPOUT = 0.1
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


class CausalSelfAttention(nn.Module):
    """Scaled Dot-Product Self-Attention with Causal Mask"""
    def __init__(self, n_embd, n_head):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.d_k = n_embd // n_head
        self.qkv = nn.Linear(n_embd, 3 * n_embd, bias=False)  # 一次算 Q/K/V
        self.proj = nn.Linear(n_embd, n_embd)
        self.drop = nn.Dropout(DROPOUT)
        # 下三角因果掩码
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(BLOCK_SIZE, BLOCK_SIZE)).view(1, 1, BLOCK_SIZE, BLOCK_SIZE),
        )

    def forward(self, x):
        B, T, C = x.shape
        qkv = self.qkv(x)                                       # (B, T, 3C)
        q, k, v = qkv.split(C, dim=2)
        q = q.view(B, T, self.n_head, self.d_k).transpose(1, 2) # (B, h, T, d_k)
        k = k.view(B, T, self.n_head, self.d_k).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.d_k).transpose(1, 2)

        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_k)   # (B, h, T, T)
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)

        out = att @ v                                           # (B, h, T, d_k)
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.drop(self.proj(out))


class Block(nn.Module):
    """Pre-LN Transformer Block: LN → MHA → + → LN → FFN → +"""
    def __init__(self, n_embd, n_head):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head)
        self.ln2 = nn.LayerNorm(n_embd)
        self.ffn = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(DROPOUT),
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x


class MiniGPT(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, N_EMBD)
        self.pos_emb = nn.Embedding(BLOCK_SIZE, N_EMBD)  # Learned PE
        self.blocks = nn.Sequential(*[Block(N_EMBD, N_HEAD) for _ in range(N_LAYER)])
        self.ln_f = nn.LayerNorm(N_EMBD)
        self.head = nn.Linear(N_EMBD, vocab_size, bias=False)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        pos = torch.arange(T, device=idx.device)
        x = self.tok_emb(idx) + self.pos_emb(pos)  # (B, T, C)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.head(x)                       # (B, T, vocab)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
            )
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0):
        for _ in range(max_new_tokens):
            idx_crop = idx[:, -BLOCK_SIZE:]        # 只保留末尾 BLOCK_SIZE 个 token
            logits, _ = self(idx_crop)
            logits = logits[:, -1, :] / temperature
            probs = F.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, 1)
            idx = torch.cat([idx, next_id], dim=1)
        return idx

统计参数量:

model = MiniGPT(vocab_size).to(DEVICE)
n_params = sum(p.numel() for p in model.parameters())
print(f"MiniGPT 参数量:{n_params/1e6:.2f} M")
# 预期 ~0.8M 参数

步骤 3:训练(20 分钟)

# ===== 训练配置 =====
BATCH_SIZE = 32
LR = 3e-4
MAX_ITERS = 3000
EVAL_EVERY = 300

def get_batch(split):
    src = train_data if split == "train" else val_data
    ix = torch.randint(len(src) - BLOCK_SIZE - 1, (BATCH_SIZE,))
    x = torch.stack([src[i : i + BLOCK_SIZE] for i in ix])
    y = torch.stack([src[i + 1 : i + BLOCK_SIZE + 1] for i in ix])
    return x.to(DEVICE), y.to(DEVICE)

optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

history = {"train": [], "val": []}
for step in range(MAX_ITERS):
    if step % EVAL_EVERY == 0 or step == MAX_ITERS - 1:
        model.eval()
        with torch.no_grad():
            train_loss = model(*get_batch("train"))[1].item()
            val_loss = model(*get_batch("val"))[1].item()
        history["train"].append((step, train_loss))
        history["val"].append((step, val_loss))
        print(f"step {step}: train {train_loss:.4f} | val {val_loss:.4f}")
        model.train()

    x, y = get_batch("train")
    _, loss = model(x, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# 绘制损失曲线
import matplotlib.pyplot as plt
plt.plot(*zip(*history["train"]), label="train")
plt.plot(*zip(*history["val"]), label="val")
plt.xlabel("Step"); plt.ylabel("Loss"); plt.legend(); plt.grid(True, alpha=0.3)
plt.savefig("minigpt_loss.png", dpi=150)
plt.show()

步骤 4:生成与观察(10 分钟)

# ===== 生成文本 =====
model.eval()
context = torch.tensor([[stoi['\n']]], dtype=torch.long, device=DEVICE)
out = model.generate(context, max_new_tokens=500, temperature=1.0)[0].tolist()
print("========== 生成结果 ==========")
print(decode(out))

预期输出(在 3000 步训练后):文本形式大致正确——有大写字母开头、以句点结束,偶尔能看到像英语单词的片段。但语义仍然混乱——这是一个 0.8M 参数的"小 GPT"能达到的极限。

延伸思考:试着对比下列实验的效果——

  1. N_LAYER 从 4 改为 1:损失会下降多少?生成质量如何变化?
  2. N_HEAD 从 4 改为 1(其余不变):单头 vs. 多头的差异
  3. 把因果掩码去掉(变成双向 Self-Attention):loss 会下降,但生成会是什么样?(提示:会变成"看答案抄答案",val loss 骗人)

步骤 5:交付物

  • minigpt.py(或 notebook),包含完整模型与训练代码
  • minigpt_loss.png 训练/验证损失曲线
  • 一段 500 字符生成样例
  • 一份 1 页报告:分析 N_LAYERN_HEAD 变化对性能的影响

选项 B:BERT 多头注意力可视化

目标

加载预训练的 bert-base-uncased(或中文的 bert-base-chinese),对一句有歧义的句子提取每一层、每一个头的注意力权重,并用 matplotlib 画出热力图。直观观察不同头学到了不同的语言结构

步骤 1:安装与模型加载(5 分钟)

pip install transformers torch matplotlib seaborn
from transformers import AutoTokenizer, AutoModel
import torch

MODEL_NAME = "bert-base-chinese"   # 或 "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME, output_attentions=True)
model.eval()
print(f"模型层数:{model.config.num_hidden_layers}")
print(f"每层头数:{model.config.num_attention_heads}")

步骤 2:提取注意力权重(10 分钟)

# 一句有趣的中文——"他用望远镜看到了那只鸟"有句法歧义
sentence = "他用望远镜看到了那只鸟"
inputs = tokenizer(sentence, return_tensors="pt")
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
print("Tokens:", tokens)

with torch.no_grad():
    outputs = model(**inputs)

# attentions 是一个元组,长度 = 层数;每个元素形状 (batch, n_head, seq, seq)
attentions = outputs.attentions
print(f"层数 × 形状:{len(attentions)} × {attentions[0].shape}")

步骤 3:绘制单个头的热力图(15 分钟)

import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib as mpl

# 中文字体(Linux/Mac 通常有 SimHei / Arial Unicode MS;没有则注释掉这两行)
mpl.rcParams["font.sans-serif"] = ["Arial Unicode MS", "SimHei"]
mpl.rcParams["axes.unicode_minus"] = False


def plot_attention(layer: int, head: int):
    attn = attentions[layer][0, head].numpy()   # (seq, seq)
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(
        attn,
        xticklabels=tokens,
        yticklabels=tokens,
        cmap="RdPu",
        ax=ax,
        cbar_kws={"label": "Attention weight"},
    )
    ax.set_title(f"Layer {layer}, Head {head}")
    ax.set_xlabel("Key (被看的 token)")
    ax.set_ylabel("Query (主动看的 token)")
    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    plt.savefig(f"attn_L{layer}_H{head}.png", dpi=150)
    plt.show()


# 先看 Layer 0, Head 0
plot_attention(0, 0)

# 再看深层的几个头
plot_attention(6, 3)
plot_attention(11, 7)

如何阅读热力图

  • 横轴(Key):被关注的 token
  • 纵轴(Query):发起关注的 token
  • 颜色越深:注意力权重越高
  • [CLS] token 通常吸引大量注意力(作为句子表示)
  • [SEP] token 在每层的注意力图里扮演特殊角色

步骤 4:比较多个头的"角色分化"(15 分钟)

# 在同一层里画出所有 12 个头,观察它们的模式差异
def plot_all_heads(layer: int):
    n_heads = attentions[layer].shape[1]
    fig, axes = plt.subplots(3, 4, figsize=(16, 12))
    for h in range(n_heads):
        ax = axes[h // 4, h % 4]
        attn = attentions[layer][0, h].numpy()
        sns.heatmap(
            attn,
            xticklabels=tokens,
            yticklabels=tokens,
            cmap="RdPu",
            ax=ax,
            cbar=False,
        )
        ax.set_title(f"Head {h}")
        ax.tick_params(axis="x", rotation=45)
    plt.suptitle(f"BERT Layer {layer} — All Heads", fontsize=16)
    plt.tight_layout()
    plt.savefig(f"attn_layer{layer}_all_heads.png", dpi=150)
    plt.show()


plot_all_heads(0)     # 第 1 层:常偏向局部(邻近 token)
plot_all_heads(6)     # 中间层:开始出现句法模式
plot_all_heads(11)    # 最后一层:更集中到 [CLS]

观察重点

  • 哪些头呈现"对角线模式"(每个 token 主要看自己)?
  • 哪些头呈现"跨步模式"(看前一个/后一个 token)?
  • 哪些头把注意力集中到 [CLS][SEP]
  • 深层(10-11 层)和浅层(0-1 层)的注意力图有什么结构差异?

参考论文:Clark et al. (2019) What Does BERT Look At? An Analysis of BERT's Attention——他们发现某些头专门追踪直接宾语、同位语、共指等句法关系。

步骤 5:进阶 —— 追踪特定 token 的注意力流向

挑一个有趣的 token(比如"鸟"),看它在每一层、每一个头被关注得多不多:

import numpy as np

target_token = "鸟"
target_idx = tokens.index(target_token)

# 收集:(layer, head, 所有其它 token 对该 token 的平均注意力)
scores = np.zeros((model.config.num_hidden_layers, model.config.num_attention_heads))
for l in range(model.config.num_hidden_layers):
    for h in range(model.config.num_attention_heads):
        # 所有 token 的 Query 对 target_token 这一 Key 的权重的均值
        scores[l, h] = attentions[l][0, h, :, target_idx].mean().item()

fig, ax = plt.subplots(figsize=(10, 6))
sns.heatmap(scores, cmap="RdPu", ax=ax,
            xticklabels=range(model.config.num_attention_heads),
            yticklabels=range(model.config.num_hidden_layers))
ax.set_xlabel("Head")
ax.set_ylabel("Layer")
ax.set_title(f"注意力流向 '{target_token}' 的平均权重")
plt.tight_layout()
plt.savefig(f"flow_to_{target_token}.png", dpi=150)
plt.show()

步骤 6:交付物

  • 至少 3 张单头热力图(不同层、不同头)
  • 一张"全头图"(某一层的 12 个头并排)
  • 一张"token 注意力流向"图
  • 一份 1 页观察报告:
    • 选一个头并解释它可能学到了什么结构(句法?局部?还是 [CLS] 聚合?)
    • 浅层和深层的注意力模式有什么差异?
    • 结合本讲 2.3 节"多头为什么好于单头"的讨论,你是否在实验中看到了不同头的"角色分化"?

延伸:两条路径都跑通后的挑战题(可选)

任务:把 MiniGPT 的 N_LAYER 从 4 增加到 8,或把 N_EMBD 从 128 增加到 256,观察:

  • 参数量的变化(用公式估算,再打印验证)
  • 训练收敛速度
  • 生成文本质量

思考:为什么 loss 下降不是线性的?这和 Chinchilla 规模律有什么联系?

任务:把 MiniGPT 的 Learned PE 换成 RoPE(参考 LLaMA 源码)。

思考

  • 如果训练时 BLOCK_SIZE=128,推理时能处理 BLOCK_SIZE=256 的输入吗?为什么 RoPE 可以,Learned PE 不行?
  • 尝试在推理时用 NTK-aware RoPE(改大频率基),看看外推效果。

任务:实现 MiniGPT 生成时的 KV Cache——每步生成只算新 token 的 Q,复用之前的 K/V。

思考

  • KV Cache 把生成复杂度从什么降到什么?
  • KV Cache 只对 Decoder-only 有用,为什么 Encoder-only(BERT)不需要?

实验评分标准

维度权重
代码能跑通,产出正确40%
报告观察准确,联系课程理论30%
可视化/图表清晰20%
挑战题或延伸思考10%

常见坑

  1. 选项 A 中 causal mask 如果忘记,模型会"看答案抄答案",val loss 骗人地低,但生成时完全乱
  2. 选项 A 中 apply_chat_template 不适用(这是字符级模型),直接用字符 ID
  3. 选项 B 中 output_attentions=True 必须在 from_pretrainedforward 时显式开启
  4. 中文字体缺失会让热力图标签显示为方块——Mac 推荐 "Arial Unicode MS",Linux 可装 fonts-noto-cjk