人工智能实践(语言智能)
第6讲:大模型微调

6.2 监督微调(SFT)

完整 SFT 流水线:加载 base → chat template → tokenize + masking → DataCollator → Trainer → 监控

SFT 完整流水线

监督微调(Supervised Fine-Tuning, SFT)的工程实现比名字听起来要繁琐——它并不只是"把数据灌进去让模型学",而是要逐个环节处理数据格式、标签掩码、损失计算、显存管理。任何一个环节做错,训练都不会报错,但模型上线后就是不灵

整个流水线可以抽象成六步:

我们逐步拆开看。

第 1 步:加载 base 模型和 tokenizer

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = "Qwen/Qwen2.5-1.5B"  # base 版本,不是 -Instruct

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

# 关键:如果 tokenizer 没有 pad_token,手动指定(用 eos_token 复用)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

新手最常见的错:加载了 -Instruct 版本继续做 SFT。SFT 的起点应该是 base 模型(未经过任何指令微调),不然你是在"二次微调已经对齐过的模型",极易导致灾难性遗忘(catastrophic forgetting)。

第 2 步:应用 chat template

不同模型家族用的聊天模板(chat template)完全不同,手动拼接几乎一定会出错。正确做法是让 tokenizer.apply_chat_template() 帮你做。

以 Qwen2.5 的 ChatML 为例

messages = [
    {"role": "system", "content": "你是一个北大软微的课程助教。"},
    {"role": "user", "content": "请用三句话介绍研究生手册的迟到规定。"},
    {"role": "assistant", "content": "根据手册第 3 章 ..."}
]

text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,                # 先拿到字符串,便于调试
    add_generation_prompt=False    # 训练时不加生成提示
)
print(text)

输出的 ChatML 格式:

<|im_start|>system
你是一个北大软微的课程助教。<|im_end|>
<|im_start|>user
请用三句话介绍研究生手册的迟到规定。<|im_end|>
<|im_start|>assistant
根据手册第 3 章 ...<|im_end|>

关键原则:永远使用 tokenizer.apply_chat_template()永远不要手动拼接 <|im_start|>。不同模型(Qwen / Llama / Gemma / DeepSeek)的特殊 token ID 完全不同,手动拼接不会报错但会静默地训坏模型

第 3 步:Tokenize 与掩码(masking)

这是整个 SFT 最关键的一步

为什么需要掩码

标准的语言建模会在所有 token 上计算交叉熵损失——但 SFT 里,我们只想让模型学会"怎么回答",不是学会"怎么复述用户的问题"。

考虑一条训练样本:system + user + assistant。如果不做掩码:

  • 模型会花一半梯度学"怎么生成用户的问题"——这没用,用户问题在推理时是给定的输入
  • 更糟:模型可能把自己的回答风格"污染到 user 角色"上去

掩码损失公式

SFT 的训练目标是只在 assistant 角色的 token 上计算交叉熵:

LSFT(θ)=1At=1TmtlogPθ(xtx<t)\mathcal{L}_{\text{SFT}}(\theta) = -\frac{1}{|\mathcal{A}|} \sum_{t=1}^{T} m_t \cdot \log P_\theta(x_t \mid x_{<t})

其中:

  • xtx_t 是第 tt 个 token
  • x<tx_{<t} 是前 t1t-1 个 token
  • mt{0,1}m_t \in \{0, 1\} 是二值掩码:assistant token 为 1,其他为 0
  • A={t:mt=1}\mathcal{A} = \{ t : m_t = 1 \} 是所有 assistant token 的位置集合
  • A|\mathcal{A}| 做归一化,避免长 assistant 样本主导损失

掩码示意

角色内容mask
system你是一个北大软微的课程助教。0 0 0 ...
user请用三句话介绍 ...0 0 0 ...
assistant根据手册第 3 章 ...1 1 1 ...

在 TRL 里的实现

Hugging Face TRL 的 SFTTrainer 已经内置了 chat template 和掩码处理:

from trl import SFTTrainer, SFTConfig
from datasets import Dataset

# 准备数据——格式是 messages 列表
train_data = Dataset.from_list([
    {"messages": [
        {"role": "user", "content": "什么是 LoRA?"},
        {"role": "assistant", "content": "LoRA 是一种低秩适配方法 ..."}
    ]},
    # ... 更多样本
])

sft_config = SFTConfig(
    output_dir="./qwen2.5-sft",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    max_length=2048,
    packing=False,            # 是否把多条样本拼接到一条里(省显存但增复杂度)
    assistant_only_loss=True, # 仅在 assistant token 上算损失(默认 True)
    bf16=True,
)

trainer = SFTTrainer(
    model=model,
    args=sft_config,
    train_dataset=train_data,
    processing_class=tokenizer,
)

trainer.train()

第 4 步:DataCollator(padding / packing)

一个 batch 里不同样本长度不同,需要对齐。两种策略:

策略做法优缺点
Padding把所有样本补到该 batch 内最长样本的长度简单、兼容性好;浪费算力(pad 位置不算梯度但仍占显存)
Packing把多条短样本拼接到一条长序列里,用 attention mask 区分算力利用率高;实现复杂,需要"block-diagonal attention"

TRL 里开启 packing 只需 packing=True短样本多、长样本少的数据集开 packing 能把吞吐量提升 1.5-3×。

第 5 步:Trainer 训练循环

一次 SFT 的典型超参:

sft_config = SFTConfig(
    output_dir="./output",

    # ===== 训练规模 =====
    num_train_epochs=3,                  # SFT 通常 1-3 epoch 足够,过多会过拟合
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,       # 有效 batch size = 4 × 4 × GPU 数

    # ===== 优化器 =====
    learning_rate=2e-4,                  # LoRA: 1e-4 到 5e-4 / 全参: 1e-5 到 5e-5
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    weight_decay=0.01,

    # ===== 精度 / 显存 =====
    bf16=True,                           # A100/H100 用 bf16,更稳
    gradient_checkpointing=True,         # 省显存(约 30-50%),换 ~20% 速度
    gradient_checkpointing_kwargs={"use_reentrant": False},

    # ===== 日志 =====
    logging_steps=10,
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=200,
    save_total_limit=2,

    # ===== 其他 =====
    max_length=2048,
    packing=False,
    seed=42,
    report_to="none",                    # 或 "wandb" / "tensorboard"
)

第 6 步:监控损失

一条健康的 SFT 损失曲线长这样:

  • 前 50 步:损失从 ~2.5 快速降到 ~1.5
  • 中段:缓慢稳定下降到 0.8-1.2
  • 末段:平稳或小幅波动

异常信号

  • 损失不下降:lr 太小 / 数据格式错 / 掩码全为 0
  • 损失快速崩到 0:掩码错了(所有 token 都算损失 + 模型过拟合)
  • 损失震荡:lr 太大 / batch 太小 / 数据噪声大
  • eval_loss 远高于 train_loss:过拟合,缩小 epoch 或增大 dropout

常见 Bug 清单

症状:模型在 padding 位置生成垃圾 token。

原因labels 里的 pad_token_id 没被设为 -100,导致模型被"教"去预测 pad。

修复

# 确保 DataCollator 把 pad 位置的 label 设为 -100
from transformers import DataCollatorForLanguageModeling

collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
    pad_to_multiple_of=8,
)
# TRL 的 SFTTrainer 默认会处理这个,手写 Trainer 时容易漏

症状:模型生成后停不下来,一直重复直到 max_new_tokens

原因:训练数据里的每条 assistant 回答末尾没加 EOS token,模型永远学不到"何时停"。

修复:检查 apply_chat_template 后的字符串末尾是否有 <|im_end|><|endoftext|>;必要时手动追加:

if not text.endswith(tokenizer.eos_token):
    text = text + tokenizer.eos_token

症状:模型微调时能收敛,但推理时输出质量奇差。

原因:训练时用的是 ChatML,推理时用的是 Llama 格式(或反之)。必须完全一致

修复:推理阶段永远用同一个 tokenizer + 同一个 apply_chat_template(messages, add_generation_prompt=True)

症状:Trainer 抛 KeyError: 'messages' 或训练结果异常。

原因:TRL 需要特定列名(messagestext);数据集里有额外的列(instructionoutput)也会干扰 collator。

修复:显式删除多余列:

cols_to_keep = ["messages"]
cols_to_remove = [c for c in dataset.column_names if c not in cols_to_keep]
dataset = dataset.remove_columns(cols_to_remove)

本节小结

概念要点
起点base 模型,不是 -Instruct 版本
格式apply_chat_template永不手动拼接
掩码只在 assistant token 上算损失,公式 L=tAlogPθ(xtx<t)\mathcal{L} = -\sum_{t \in \mathcal{A}} \log P_\theta(x_t \| x_{<t})
Padding vs. Packing短样本多 → packing 省 1.5-3× 算力
典型超参lr=2e-4(LoRA)/ 2e-5(全参),1-3 epoch,bf16,grad_ckpt
健康损失从 ~2.5 降到 0.8-1.2,平滑无震荡
必查 4 Bugpad mask、EOS 缺失、template 不一致、列名漂移