实验3:用 DSPy 编译一个中文问答 pipeline
在中文 QA 任务上对比手写 prompt 与 BootstrapFewShot 编译后的效果,理解声明式优化的威力
实验概述
本实验将在一个中文问答任务上完整走一遍 DSPy 工作流,亲手验证"编译后的 pipeline 能否超过手写 prompt"。我们会先搭一个最简单的 dspy.Predict,再换成 dspy.ChainOfThought,最后用 BootstrapFewShot 编译——并在同一个评估集上对比三者的准确率。
| 项目 | 详情 |
|---|---|
| 任务 | 中文常识 QA(基于 CMRC2018 或自建小数据集) |
| 后端 LM | gpt-4o-mini 或本地 Qwen3-7B-Instruct(via vLLM / Ollama) |
| DSPy 版本 | dspy-ai >= 2.5 |
| 评估指标 | Exact Match(EM)+ LLM-as-Judge 的软匹配 |
| 预计时间 | 约 60 分钟(含编译) |
API 成本预警:若使用 OpenAI 后端,编译阶段会触发 50–200 次 LM 调用(取决于训练集大小和 Bootstrap 轮数)。建议先用 10–20 条样本做 smoke test,再放大到完整训练集。
实验步骤
步骤 1:环境准备与数据加载(10 分钟)
先安装 DSPy 并配置 LM 后端。
pip install dspy-ai datasets jiebaimport dspy
import random
from datasets import load_dataset
# ===== 1. 配置 LM 后端 =====
# 方案 A: OpenAI
lm = dspy.LM(
model="openai/gpt-4o-mini",
api_key="sk-...", # 你的 key
max_tokens=512,
)
dspy.configure(lm=lm)
# 方案 B: 本地 vLLM (如果在 AutoDL 上)
# lm = dspy.LM(
# model="openai/Qwen3-7B-Instruct",
# api_base="http://localhost:8000/v1",
# api_key="EMPTY",
# )
# ===== 2. 加载中文 QA 数据 =====
# 用 CMRC2018 的片段做简化版 QA
raw = load_dataset("cmrc2018", split="train[:300]", trust_remote_code=True)
# 每条样本保留 (question, context, answer) 三元组
records = []
for x in raw:
if x["answers"]["text"]:
records.append({
"question": x["question"],
"context": x["context"][:500], # 裁到 500 字以控制 token
"answer": x["answers"]["text"][0],
})
random.seed(42)
random.shuffle(records)
train_data = [dspy.Example(**r).with_inputs("question", "context") for r in records[:60]]
dev_data = [dspy.Example(**r).with_inputs("question", "context") for r in records[60:110]]
test_data = [dspy.Example(**r).with_inputs("question", "context") for r in records[110:160]]
print(f"Train: {len(train_data)}, Dev: {len(dev_data)}, Test: {len(test_data)}")
print(f"示例: {train_data[0].question} -> {train_data[0].answer}")步骤 2:定义 Signature(10 分钟)
Signature 是 DSPy 的核心接口——声明输入输出,不写 prompt。
class ChineseQA(dspy.Signature):
"""根据给定的中文上下文回答问题.答案应简洁,优先直接从上下文中提取."""
context = dspy.InputField(desc="包含答案的中文段落")
question = dspy.InputField(desc="需要回答的中文问题")
answer = dspy.OutputField(desc="简短的答案,通常是 1–10 个字")对比一下手写 prompt 版本作为 baseline(我们稍后会用它作对比组):
HANDCRAFTED_PROMPT = """你是一个中文问答助手.请根据下面提供的上下文回答问题.
答案应该简洁,直接从上下文中提取相关信息.
上下文: {context}
问题: {question}
答案:"""
def handcrafted_predict(context, question):
prompt = HANDCRAFTED_PROMPT.format(context=context, question=question)
# 直接调用 lm(绕过 DSPy 的抽象,纯手工)
return lm(prompt)[0].strip()步骤 3:定义 Module 与 Metric(10 分钟)
我们比较两种 Module:
# Variant 1: 最简单的 Predict
simple_qa = dspy.Predict(ChineseQA)
# Variant 2: ChainOfThought
cot_qa = dspy.ChainOfThought(ChineseQA)
# 测试一下 Variant 1
pred = simple_qa(
context="北京大学成立于 1898 年,原名京师大学堂,是中国近代第一所国立大学.",
question="北京大学的前身是什么?"
)
print(pred.answer) # 应该输出 "京师大学堂"定义评估指标——中文 QA 中直接 EM 往往偏严(标点、同义词),我们做一个宽松的"包含匹配":
import re
def normalize_zh(text):
"""去标点、去空白,便于模糊匹配."""
return re.sub(r"[\s\p{P}]", "", text.strip())
def answer_match(example, pred, trace=None):
"""判断预测答案是否包含金标答案(或反之)."""
gold = normalize_zh(example.answer)
pred_ans = normalize_zh(pred.answer)
if not gold or not pred_ans:
return False
return (gold in pred_ans) or (pred_ans in gold)步骤 4:Zero-shot 评估基线(10 分钟)
在动用编译器之前,先记录三个 baseline 的未编译分数:
from dspy.evaluate import Evaluate
evaluator = Evaluate(
devset=dev_data,
metric=answer_match,
num_threads=4,
display_progress=True,
)
# Baseline 1: 手写 prompt
def run_handcrafted():
correct = 0
for ex in dev_data:
ans = handcrafted_predict(ex.context, ex.question)
pred = dspy.Prediction(answer=ans)
if answer_match(ex, pred):
correct += 1
return correct / len(dev_data)
print("===== Zero-shot Baselines =====")
print(f"手写 prompt: {run_handcrafted():.2%}")
print(f"dspy.Predict (0-shot): {evaluator(simple_qa):.2%}")
print(f"dspy.ChainOfThought (0-shot): {evaluator(cot_qa):.2%}")预期观察:三者在 Dev 集上应该都在 50–70% 区间;手写 prompt 和 dspy.Predict 接近,ChainOfThought 略高(因为中间推理帮助了抽取)。
步骤 5:使用 BootstrapFewShot 编译(15 分钟)
现在召唤编译器——让 DSPy 在训练集上自动生成 Few-shot 示例:
from dspy.teleprompt import BootstrapFewShot
# 配置 Teleprompter
teleprompter = BootstrapFewShot(
metric=answer_match,
max_bootstrapped_demos=4, # 最多注入 4 个自举生成的示例
max_labeled_demos=4, # 再从训练集直接挑 4 个标注示例
max_rounds=1, # 一轮编译足够
)
# 编译 CoT 版本(编译会触发约 50–100 次 LM 调用)
print("开始编译 ChainOfThought...")
compiled_cot = teleprompter.compile(
student=dspy.ChainOfThought(ChineseQA),
trainset=train_data,
)
print("编译完成!")
# 看看编译器生成了哪些 Few-shot 示例
for i, demo in enumerate(compiled_cot.predict.demos[:2]):
print(f"\n--- Demo {i} ---")
print(f"Question: {demo.question[:80]}")
print(f"Rationale: {demo.rationale[:150] if hasattr(demo, 'rationale') else 'N/A'}")
print(f"Answer: {demo.answer}")
# 在 dev 集上评估编译后的程序
compiled_score = evaluator(compiled_cot)
print(f"\n编译后 ChainOfThought: {compiled_score:.2%}")预期观察:编译后应该相比未编译的 CoT 再涨 5–15 个百分点。DSPy 论文在 GSM8K 上从 72% → 78%,在中文 QA 上数量级类似。
步骤 6:Test 集评估与错误分析(10 分钟)
前面所有比较都在 Dev 集上——最终报数要用 Test 集以避免过拟合。
test_evaluator = Evaluate(devset=test_data, metric=answer_match, num_threads=4)
results = {
"手写 prompt": run_handcrafted_on(test_data), # 你需要改成 test 版本
"dspy.Predict (0-shot)": test_evaluator(simple_qa),
"dspy.CoT (0-shot)": test_evaluator(cot_qa),
"dspy.CoT (compiled)": test_evaluator(compiled_cot),
}
print("\n===== Test Set Results =====")
for name, score in results.items():
print(f"{name:<32} {score:.2%}")错误分析——找出编译后仍然错的样本,归类:
errors = []
for ex in test_data:
pred = compiled_cot(context=ex.context, question=ex.question)
if not answer_match(ex, pred):
errors.append({
"question": ex.question,
"gold": ex.answer,
"pred": pred.answer,
"rationale": getattr(pred, "rationale", ""),
})
print(f"\n错误数: {len(errors)} / {len(test_data)}")
for e in errors[:5]:
print(f"\n问: {e['question']}")
print(f" 金标: {e['gold']}")
print(f" 预测: {e['pred']}")
print(f" 推理: {e['rationale'][:120]}")常见错误类型(可在报告中分类):
- 同义改写——"1898 年" vs. "清光绪二十四年"(本质正确,指标误判)
- 范围过宽——金标 "京师大学堂",预测 "京师大学堂,是中国第一所..."
- 真错误——推理链偏离上下文,产生幻觉
加分项:用 MIPRO 做多阶段联合优化
如果时间充裕,可以把 RAG 结构完整搭起来(检索 + CoT),并用 MIPRO 做联合编译:
from dspy.teleprompt import MIPROv2
class RetrievalQA(dspy.Module):
def __init__(self):
self.retrieve = dspy.Retrieve(k=3) # 需要额外配置检索器
self.answer = dspy.ChainOfThought(ChineseQA)
def forward(self, question):
passages = self.retrieve(question).passages
return self.answer(context="\n".join(passages), question=question)
mipro = MIPROv2(metric=answer_match, auto="light")
compiled_rag = mipro.compile(RetrievalQA(), trainset=train_data)对比 BootstrapFewShot 与 MIPROv2 在同一 Test 集上的差距——MIPRO 通常在多阶段 pipeline 上有 +3–8% 额外提升,但编译耗时更长。
交付物清单
完成实验后,请提交以下内容:
- 四行对比表:手写 prompt / Predict / CoT / CoT-compiled 在 Test 集上的准确率
- 编译器生成的 2 个 Demo 示例截图或打印
- 错误分析:至少 5 个失败案例,每个标注错误类型
- 1 页书面分析,回答以下问题:
- 编译后的 CoT 相比手写 prompt 主要赢在哪里?(是格式、是推理、还是示例选择?)
- 如果把训练集从 60 条增加到 200 条,你预期哪个方法会进一步提升,哪个会饱和?
- 观察编译器生成的 Few-shot 示例——它们和人工挑选的会是一样的吗?如果不一样,哪里不同?
- (可选)对比
BootstrapFewShot与MIPROv2的结果与时间
时间预估:
- 环境与数据:~10 分钟
- Signature / Module / Metric:~20 分钟
- Zero-shot baseline:~10 分钟
- 编译 + Test 评估:~20 分钟
- 错误分析与书面报告:~30 分钟(课后)
- 总计堂上时间:约 60 分钟