实验 2:Transformer 最小实现与注意力可视化
两选一——用 100 行 PyTorch 实现 MiniGPT 并训练;或用 transformers 加载 BERT 并可视化多头注意力
实验概述
本讲的实验有两条路径,你可以根据兴趣和硬件条件选其中一条。两条都能把这一讲的抽象知识落到键盘上。
| 选项 | 标题 | 硬件要求 | 适合人群 |
|---|---|---|---|
| A | 100 行 PyTorch 实现 MiniGPT | GPU(RTX 3060 或以上即可);CPU 也能跑小规模 | 想彻底理解内部机制、偏算法 |
| B | BERT 多头注意力可视化 | CPU 即可,任意电脑 | 想用现成模型、偏应用与可解释性 |
推荐:时间充裕的同学两条都做——先做 B(熟悉 transformers API 与注意力图谱),再做 A(把公式和代码一一对应)。
选项 A:100 行 PyTorch 实现 MiniGPT
目标
从零实现一个极简的 Decoder-only Transformer(GPT 风格),在莎士比亚数据集(Shakespeare tiny,约 1MB)上训练一个字符级语言模型。训练完成后能生成仿莎士比亚风格的文本。
本实验致敬 Andrej Karpathy 的 nanoGPT 项目,但刻意把代码精简到 ~120 行,方便逐行对照公式。
步骤 1:环境与数据准备(10 分钟)
pip install torch tiktoken numpy matplotlib下载莎士比亚数据集:
import urllib.request, os
url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
if not os.path.exists("input.txt"):
urllib.request.urlretrieve(url, "input.txt")
text = open("input.txt", "r", encoding="utf-8").read()
print(f"数据长度:{len(text)} 字符")
print(f"前 200 字:{text[:200]}")构建字符级词表:
import torch
chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: "".join([itos[i] for i in l])
data = torch.tensor(encode(text), dtype=torch.long)
n_train = int(0.9 * len(data))
train_data, val_data = data[:n_train], data[n_train:]
print(f"词表大小:{vocab_size}, 训练 token:{len(train_data)}")步骤 2:实现 MiniGPT(30 分钟)
把第 2.2、2.3 两节的公式整合成一个 120 行以内的 GPT:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# ===== 超参数 =====
BLOCK_SIZE = 128 # 上下文长度
N_LAYER = 4 # Transformer Block 数量
N_HEAD = 4 # 多头数
N_EMBD = 128 # d_model
DROPOUT = 0.1
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
class CausalSelfAttention(nn.Module):
"""Scaled Dot-Product Self-Attention with Causal Mask"""
def __init__(self, n_embd, n_head):
super().__init__()
assert n_embd % n_head == 0
self.n_head = n_head
self.d_k = n_embd // n_head
self.qkv = nn.Linear(n_embd, 3 * n_embd, bias=False) # 一次算 Q/K/V
self.proj = nn.Linear(n_embd, n_embd)
self.drop = nn.Dropout(DROPOUT)
# 下三角因果掩码
self.register_buffer(
"mask",
torch.tril(torch.ones(BLOCK_SIZE, BLOCK_SIZE)).view(1, 1, BLOCK_SIZE, BLOCK_SIZE),
)
def forward(self, x):
B, T, C = x.shape
qkv = self.qkv(x) # (B, T, 3C)
q, k, v = qkv.split(C, dim=2)
q = q.view(B, T, self.n_head, self.d_k).transpose(1, 2) # (B, h, T, d_k)
k = k.view(B, T, self.n_head, self.d_k).transpose(1, 2)
v = v.view(B, T, self.n_head, self.d_k).transpose(1, 2)
att = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_k) # (B, h, T, T)
att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
att = F.softmax(att, dim=-1)
out = att @ v # (B, h, T, d_k)
out = out.transpose(1, 2).contiguous().view(B, T, C)
return self.drop(self.proj(out))
class Block(nn.Module):
"""Pre-LN Transformer Block: LN → MHA → + → LN → FFN → +"""
def __init__(self, n_embd, n_head):
super().__init__()
self.ln1 = nn.LayerNorm(n_embd)
self.attn = CausalSelfAttention(n_embd, n_head)
self.ln2 = nn.LayerNorm(n_embd)
self.ffn = nn.Sequential(
nn.Linear(n_embd, 4 * n_embd),
nn.GELU(),
nn.Linear(4 * n_embd, n_embd),
nn.Dropout(DROPOUT),
)
def forward(self, x):
x = x + self.attn(self.ln1(x))
x = x + self.ffn(self.ln2(x))
return x
class MiniGPT(nn.Module):
def __init__(self, vocab_size):
super().__init__()
self.tok_emb = nn.Embedding(vocab_size, N_EMBD)
self.pos_emb = nn.Embedding(BLOCK_SIZE, N_EMBD) # Learned PE
self.blocks = nn.Sequential(*[Block(N_EMBD, N_HEAD) for _ in range(N_LAYER)])
self.ln_f = nn.LayerNorm(N_EMBD)
self.head = nn.Linear(N_EMBD, vocab_size, bias=False)
def forward(self, idx, targets=None):
B, T = idx.shape
pos = torch.arange(T, device=idx.device)
x = self.tok_emb(idx) + self.pos_emb(pos) # (B, T, C)
x = self.blocks(x)
x = self.ln_f(x)
logits = self.head(x) # (B, T, vocab)
loss = None
if targets is not None:
loss = F.cross_entropy(
logits.view(-1, logits.size(-1)),
targets.view(-1),
)
return logits, loss
@torch.no_grad()
def generate(self, idx, max_new_tokens, temperature=1.0):
for _ in range(max_new_tokens):
idx_crop = idx[:, -BLOCK_SIZE:] # 只保留末尾 BLOCK_SIZE 个 token
logits, _ = self(idx_crop)
logits = logits[:, -1, :] / temperature
probs = F.softmax(logits, dim=-1)
next_id = torch.multinomial(probs, 1)
idx = torch.cat([idx, next_id], dim=1)
return idx统计参数量:
model = MiniGPT(vocab_size).to(DEVICE)
n_params = sum(p.numel() for p in model.parameters())
print(f"MiniGPT 参数量:{n_params/1e6:.2f} M")
# 预期 ~0.8M 参数步骤 3:训练(20 分钟)
# ===== 训练配置 =====
BATCH_SIZE = 32
LR = 3e-4
MAX_ITERS = 3000
EVAL_EVERY = 300
def get_batch(split):
src = train_data if split == "train" else val_data
ix = torch.randint(len(src) - BLOCK_SIZE - 1, (BATCH_SIZE,))
x = torch.stack([src[i : i + BLOCK_SIZE] for i in ix])
y = torch.stack([src[i + 1 : i + BLOCK_SIZE + 1] for i in ix])
return x.to(DEVICE), y.to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
history = {"train": [], "val": []}
for step in range(MAX_ITERS):
if step % EVAL_EVERY == 0 or step == MAX_ITERS - 1:
model.eval()
with torch.no_grad():
train_loss = model(*get_batch("train"))[1].item()
val_loss = model(*get_batch("val"))[1].item()
history["train"].append((step, train_loss))
history["val"].append((step, val_loss))
print(f"step {step}: train {train_loss:.4f} | val {val_loss:.4f}")
model.train()
x, y = get_batch("train")
_, loss = model(x, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 绘制损失曲线
import matplotlib.pyplot as plt
plt.plot(*zip(*history["train"]), label="train")
plt.plot(*zip(*history["val"]), label="val")
plt.xlabel("Step"); plt.ylabel("Loss"); plt.legend(); plt.grid(True, alpha=0.3)
plt.savefig("minigpt_loss.png", dpi=150)
plt.show()步骤 4:生成与观察(10 分钟)
# ===== 生成文本 =====
model.eval()
context = torch.tensor([[stoi['\n']]], dtype=torch.long, device=DEVICE)
out = model.generate(context, max_new_tokens=500, temperature=1.0)[0].tolist()
print("========== 生成结果 ==========")
print(decode(out))预期输出(在 3000 步训练后):文本形式大致正确——有大写字母开头、以句点结束,偶尔能看到像英语单词的片段。但语义仍然混乱——这是一个 0.8M 参数的"小 GPT"能达到的极限。
延伸思考:试着对比下列实验的效果——
- 把
N_LAYER从 4 改为 1:损失会下降多少?生成质量如何变化? - 把
N_HEAD从 4 改为 1(其余不变):单头 vs. 多头的差异 - 把因果掩码去掉(变成双向 Self-Attention):loss 会下降,但生成会是什么样?(提示:会变成"看答案抄答案",val loss 骗人)
步骤 5:交付物
-
minigpt.py(或 notebook),包含完整模型与训练代码 -
minigpt_loss.png训练/验证损失曲线 - 一段 500 字符生成样例
- 一份 1 页报告:分析
N_LAYER或N_HEAD变化对性能的影响
选项 B:BERT 多头注意力可视化
目标
加载预训练的 bert-base-uncased(或中文的 bert-base-chinese),对一句有歧义的句子提取每一层、每一个头的注意力权重,并用 matplotlib 画出热力图。直观观察不同头学到了不同的语言结构。
步骤 1:安装与模型加载(5 分钟)
pip install transformers torch matplotlib seabornfrom transformers import AutoTokenizer, AutoModel
import torch
MODEL_NAME = "bert-base-chinese" # 或 "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME, output_attentions=True)
model.eval()
print(f"模型层数:{model.config.num_hidden_layers}")
print(f"每层头数:{model.config.num_attention_heads}")步骤 2:提取注意力权重(10 分钟)
# 一句有趣的中文——"他用望远镜看到了那只鸟"有句法歧义
sentence = "他用望远镜看到了那只鸟"
inputs = tokenizer(sentence, return_tensors="pt")
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
print("Tokens:", tokens)
with torch.no_grad():
outputs = model(**inputs)
# attentions 是一个元组,长度 = 层数;每个元素形状 (batch, n_head, seq, seq)
attentions = outputs.attentions
print(f"层数 × 形状:{len(attentions)} × {attentions[0].shape}")步骤 3:绘制单个头的热力图(15 分钟)
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib as mpl
# 中文字体(Linux/Mac 通常有 SimHei / Arial Unicode MS;没有则注释掉这两行)
mpl.rcParams["font.sans-serif"] = ["Arial Unicode MS", "SimHei"]
mpl.rcParams["axes.unicode_minus"] = False
def plot_attention(layer: int, head: int):
attn = attentions[layer][0, head].numpy() # (seq, seq)
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(
attn,
xticklabels=tokens,
yticklabels=tokens,
cmap="RdPu",
ax=ax,
cbar_kws={"label": "Attention weight"},
)
ax.set_title(f"Layer {layer}, Head {head}")
ax.set_xlabel("Key (被看的 token)")
ax.set_ylabel("Query (主动看的 token)")
plt.xticks(rotation=45, ha="right")
plt.tight_layout()
plt.savefig(f"attn_L{layer}_H{head}.png", dpi=150)
plt.show()
# 先看 Layer 0, Head 0
plot_attention(0, 0)
# 再看深层的几个头
plot_attention(6, 3)
plot_attention(11, 7)如何阅读热力图:
- 横轴(Key):被关注的 token
- 纵轴(Query):发起关注的 token
- 颜色越深:注意力权重越高
[CLS]token 通常吸引大量注意力(作为句子表示)[SEP]token 在每层的注意力图里扮演特殊角色
步骤 4:比较多个头的"角色分化"(15 分钟)
# 在同一层里画出所有 12 个头,观察它们的模式差异
def plot_all_heads(layer: int):
n_heads = attentions[layer].shape[1]
fig, axes = plt.subplots(3, 4, figsize=(16, 12))
for h in range(n_heads):
ax = axes[h // 4, h % 4]
attn = attentions[layer][0, h].numpy()
sns.heatmap(
attn,
xticklabels=tokens,
yticklabels=tokens,
cmap="RdPu",
ax=ax,
cbar=False,
)
ax.set_title(f"Head {h}")
ax.tick_params(axis="x", rotation=45)
plt.suptitle(f"BERT Layer {layer} — All Heads", fontsize=16)
plt.tight_layout()
plt.savefig(f"attn_layer{layer}_all_heads.png", dpi=150)
plt.show()
plot_all_heads(0) # 第 1 层:常偏向局部(邻近 token)
plot_all_heads(6) # 中间层:开始出现句法模式
plot_all_heads(11) # 最后一层:更集中到 [CLS]观察重点:
- 哪些头呈现"对角线模式"(每个 token 主要看自己)?
- 哪些头呈现"跨步模式"(看前一个/后一个 token)?
- 哪些头把注意力集中到
[CLS]或[SEP]? - 深层(10-11 层)和浅层(0-1 层)的注意力图有什么结构差异?
参考论文:Clark et al. (2019) What Does BERT Look At? An Analysis of BERT's Attention——他们发现某些头专门追踪直接宾语、同位语、共指等句法关系。
步骤 5:进阶 —— 追踪特定 token 的注意力流向
挑一个有趣的 token(比如"鸟"),看它在每一层、每一个头被关注得多不多:
import numpy as np
target_token = "鸟"
target_idx = tokens.index(target_token)
# 收集:(layer, head, 所有其它 token 对该 token 的平均注意力)
scores = np.zeros((model.config.num_hidden_layers, model.config.num_attention_heads))
for l in range(model.config.num_hidden_layers):
for h in range(model.config.num_attention_heads):
# 所有 token 的 Query 对 target_token 这一 Key 的权重的均值
scores[l, h] = attentions[l][0, h, :, target_idx].mean().item()
fig, ax = plt.subplots(figsize=(10, 6))
sns.heatmap(scores, cmap="RdPu", ax=ax,
xticklabels=range(model.config.num_attention_heads),
yticklabels=range(model.config.num_hidden_layers))
ax.set_xlabel("Head")
ax.set_ylabel("Layer")
ax.set_title(f"注意力流向 '{target_token}' 的平均权重")
plt.tight_layout()
plt.savefig(f"flow_to_{target_token}.png", dpi=150)
plt.show()步骤 6:交付物
- 至少 3 张单头热力图(不同层、不同头)
- 一张"全头图"(某一层的 12 个头并排)
- 一张"token 注意力流向"图
- 一份 1 页观察报告:
- 选一个头并解释它可能学到了什么结构(句法?局部?还是
[CLS]聚合?) - 浅层和深层的注意力模式有什么差异?
- 结合本讲 2.3 节"多头为什么好于单头"的讨论,你是否在实验中看到了不同头的"角色分化"?
- 选一个头并解释它可能学到了什么结构(句法?局部?还是
延伸:两条路径都跑通后的挑战题(可选)
任务:把 MiniGPT 的 N_LAYER 从 4 增加到 8,或把 N_EMBD 从 128 增加到 256,观察:
- 参数量的变化(用公式估算,再打印验证)
- 训练收敛速度
- 生成文本质量
思考:为什么 loss 下降不是线性的?这和 Chinchilla 规模律有什么联系?
任务:把 MiniGPT 的 Learned PE 换成 RoPE(参考 LLaMA 源码)。
思考:
- 如果训练时
BLOCK_SIZE=128,推理时能处理BLOCK_SIZE=256的输入吗?为什么 RoPE 可以,Learned PE 不行? - 尝试在推理时用 NTK-aware RoPE(改大频率基),看看外推效果。
任务:实现 MiniGPT 生成时的 KV Cache——每步生成只算新 token 的 Q,复用之前的 K/V。
思考:
- KV Cache 把生成复杂度从什么降到什么?
- KV Cache 只对 Decoder-only 有用,为什么 Encoder-only(BERT)不需要?
实验评分标准
| 维度 | 权重 |
|---|---|
| 代码能跑通,产出正确 | 40% |
| 报告观察准确,联系课程理论 | 30% |
| 可视化/图表清晰 | 20% |
| 挑战题或延伸思考 | 10% |
常见坑:
- 选项 A 中
causal mask如果忘记,模型会"看答案抄答案",val loss 骗人地低,但生成时完全乱 - 选项 A 中
apply_chat_template不适用(这是字符级模型),直接用字符 ID - 选项 B 中
output_attentions=True必须在from_pretrained或forward时显式开启 - 中文字体缺失会让热力图标签显示为方块——Mac 推荐 "Arial Unicode MS",Linux 可装
fonts-noto-cjk