6.2 监督微调(SFT)
完整 SFT 流水线:加载 base → chat template → tokenize + masking → DataCollator → Trainer → 监控
SFT 完整流水线
监督微调(Supervised Fine-Tuning, SFT)的工程实现比名字听起来要繁琐——它并不只是"把数据灌进去让模型学",而是要逐个环节处理数据格式、标签掩码、损失计算、显存管理。任何一个环节做错,训练都不会报错,但模型上线后就是不灵。
整个流水线可以抽象成六步:
我们逐步拆开看。
第 1 步:加载 base 模型和 tokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
model_name = "Qwen/Qwen2.5-1.5B" # base 版本,不是 -Instruct
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=True,
)
# 关键:如果 tokenizer 没有 pad_token,手动指定(用 eos_token 复用)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token新手最常见的错:加载了 -Instruct 版本继续做 SFT。SFT 的起点应该是 base 模型(未经过任何指令微调),不然你是在"二次微调已经对齐过的模型",极易导致灾难性遗忘(catastrophic forgetting)。
第 2 步:应用 chat template
不同模型家族用的聊天模板(chat template)完全不同,手动拼接几乎一定会出错。正确做法是让 tokenizer.apply_chat_template() 帮你做。
以 Qwen2.5 的 ChatML 为例
messages = [
{"role": "system", "content": "你是一个北大软微的课程助教。"},
{"role": "user", "content": "请用三句话介绍研究生手册的迟到规定。"},
{"role": "assistant", "content": "根据手册第 3 章 ..."}
]
text = tokenizer.apply_chat_template(
messages,
tokenize=False, # 先拿到字符串,便于调试
add_generation_prompt=False # 训练时不加生成提示
)
print(text)输出的 ChatML 格式:
<|im_start|>system
你是一个北大软微的课程助教。<|im_end|>
<|im_start|>user
请用三句话介绍研究生手册的迟到规定。<|im_end|>
<|im_start|>assistant
根据手册第 3 章 ...<|im_end|>关键原则:永远使用 tokenizer.apply_chat_template(),永远不要手动拼接 <|im_start|>。不同模型(Qwen / Llama / Gemma / DeepSeek)的特殊 token ID 完全不同,手动拼接不会报错但会静默地训坏模型。
第 3 步:Tokenize 与掩码(masking)
这是整个 SFT 最关键的一步。
为什么需要掩码
标准的语言建模会在所有 token 上计算交叉熵损失——但 SFT 里,我们只想让模型学会"怎么回答",不是学会"怎么复述用户的问题"。
考虑一条训练样本:system + user + assistant。如果不做掩码:
- 模型会花一半梯度学"怎么生成用户的问题"——这没用,用户问题在推理时是给定的输入
- 更糟:模型可能把自己的回答风格"污染到 user 角色"上去
掩码损失公式
SFT 的训练目标是只在 assistant 角色的 token 上计算交叉熵:
其中:
- 是第 个 token
- 是前 个 token
- 是二值掩码:assistant token 为 1,其他为 0
- 是所有 assistant token 的位置集合
- 做归一化,避免长 assistant 样本主导损失
掩码示意
| 角色 | 内容 | mask |
|---|---|---|
| system | 你是一个北大软微的课程助教。 | 0 0 0 ... |
| user | 请用三句话介绍 ... | 0 0 0 ... |
| assistant | 根据手册第 3 章 ... | 1 1 1 ... |
在 TRL 里的实现
Hugging Face TRL 的 SFTTrainer 已经内置了 chat template 和掩码处理:
from trl import SFTTrainer, SFTConfig
from datasets import Dataset
# 准备数据——格式是 messages 列表
train_data = Dataset.from_list([
{"messages": [
{"role": "user", "content": "什么是 LoRA?"},
{"role": "assistant", "content": "LoRA 是一种低秩适配方法 ..."}
]},
# ... 更多样本
])
sft_config = SFTConfig(
output_dir="./qwen2.5-sft",
num_train_epochs=3,
per_device_train_batch_size=4,
max_length=2048,
packing=False, # 是否把多条样本拼接到一条里(省显存但增复杂度)
assistant_only_loss=True, # 仅在 assistant token 上算损失(默认 True)
bf16=True,
)
trainer = SFTTrainer(
model=model,
args=sft_config,
train_dataset=train_data,
processing_class=tokenizer,
)
trainer.train()第 4 步:DataCollator(padding / packing)
一个 batch 里不同样本长度不同,需要对齐。两种策略:
| 策略 | 做法 | 优缺点 |
|---|---|---|
| Padding | 把所有样本补到该 batch 内最长样本的长度 | 简单、兼容性好;浪费算力(pad 位置不算梯度但仍占显存) |
| Packing | 把多条短样本拼接到一条长序列里,用 attention mask 区分 | 算力利用率高;实现复杂,需要"block-diagonal attention" |
TRL 里开启 packing 只需 packing=True。短样本多、长样本少的数据集开 packing 能把吞吐量提升 1.5-3×。
第 5 步:Trainer 训练循环
一次 SFT 的典型超参:
sft_config = SFTConfig(
output_dir="./output",
# ===== 训练规模 =====
num_train_epochs=3, # SFT 通常 1-3 epoch 足够,过多会过拟合
per_device_train_batch_size=4,
gradient_accumulation_steps=4, # 有效 batch size = 4 × 4 × GPU 数
# ===== 优化器 =====
learning_rate=2e-4, # LoRA: 1e-4 到 5e-4 / 全参: 1e-5 到 5e-5
lr_scheduler_type="cosine",
warmup_ratio=0.1,
weight_decay=0.01,
# ===== 精度 / 显存 =====
bf16=True, # A100/H100 用 bf16,更稳
gradient_checkpointing=True, # 省显存(约 30-50%),换 ~20% 速度
gradient_checkpointing_kwargs={"use_reentrant": False},
# ===== 日志 =====
logging_steps=10,
eval_strategy="steps",
eval_steps=100,
save_strategy="steps",
save_steps=200,
save_total_limit=2,
# ===== 其他 =====
max_length=2048,
packing=False,
seed=42,
report_to="none", # 或 "wandb" / "tensorboard"
)第 6 步:监控损失
一条健康的 SFT 损失曲线长这样:
- 前 50 步:损失从 ~2.5 快速降到 ~1.5
- 中段:缓慢稳定下降到 0.8-1.2
- 末段:平稳或小幅波动
异常信号:
- 损失不下降:lr 太小 / 数据格式错 / 掩码全为 0
- 损失快速崩到 0:掩码错了(所有 token 都算损失 + 模型过拟合)
- 损失震荡:lr 太大 / batch 太小 / 数据噪声大
- eval_loss 远高于 train_loss:过拟合,缩小 epoch 或增大 dropout
常见 Bug 清单
症状:模型在 padding 位置生成垃圾 token。
原因:labels 里的 pad_token_id 没被设为 -100,导致模型被"教"去预测 pad。
修复:
# 确保 DataCollator 把 pad 位置的 label 设为 -100
from transformers import DataCollatorForLanguageModeling
collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer,
mlm=False,
pad_to_multiple_of=8,
)
# TRL 的 SFTTrainer 默认会处理这个,手写 Trainer 时容易漏症状:模型生成后停不下来,一直重复直到 max_new_tokens。
原因:训练数据里的每条 assistant 回答末尾没加 EOS token,模型永远学不到"何时停"。
修复:检查 apply_chat_template 后的字符串末尾是否有 <|im_end|> 或 <|endoftext|>;必要时手动追加:
if not text.endswith(tokenizer.eos_token):
text = text + tokenizer.eos_token症状:模型微调时能收敛,但推理时输出质量奇差。
原因:训练时用的是 ChatML,推理时用的是 Llama 格式(或反之)。必须完全一致。
修复:推理阶段永远用同一个 tokenizer + 同一个 apply_chat_template(messages, add_generation_prompt=True)。
症状:Trainer 抛 KeyError: 'messages' 或训练结果异常。
原因:TRL 需要特定列名(messages 或 text);数据集里有额外的列(instruction、output)也会干扰 collator。
修复:显式删除多余列:
cols_to_keep = ["messages"]
cols_to_remove = [c for c in dataset.column_names if c not in cols_to_keep]
dataset = dataset.remove_columns(cols_to_remove)本节小结
| 概念 | 要点 |
|---|---|
| 起点 | base 模型,不是 -Instruct 版本 |
| 格式 | 用 apply_chat_template,永不手动拼接 |
| 掩码 | 只在 assistant token 上算损失,公式 |
| Padding vs. Packing | 短样本多 → packing 省 1.5-3× 算力 |
| 典型超参 | lr=2e-4(LoRA)/ 2e-5(全参),1-3 epoch,bf16,grad_ckpt |
| 健康损失 | 从 ~2.5 降到 0.8-1.2,平滑无震荡 |
| 必查 4 Bug | pad mask、EOS 缺失、template 不一致、列名漂移 |