6.5 规模与显存
显存估算公式、数据量与 epoch 建议、batch size 与 grad accumulation、多卡简介
训练显存的四个来源
训练一个大模型时,显存不是只放模型参数,而是由四个部分构成:
每一部分的大小取决于:参数量、精度、优化器、batch size、sequence length。我们逐个拆开。
参数与梯度
参数(weights)和梯度(gradients)的显存是最直观的一部分:
| 精度 | 每参数占用 |
|---|---|
| FP32 | 4 bytes |
| FP16 / BF16 | 2 bytes |
| INT8 | 1 byte |
| INT4 / NF4 | 0.5 byte |
梯度的精度一般和计算精度一致(BF16 训练时,梯度也是 BF16)。
LoRA 的关键好处: 只有全量的 0.5%-5%,梯度显存几乎可以忽略。
优化器状态(最大头)
AdamW 对每个可训练参数维护:
- 一阶动量 (FP32):4 bytes/参数
- 二阶动量 (FP32):4 bytes/参数
即使用的是 BF16 训练,动量仍然保持 FP32(业界标准,否则极不稳定)。
这就是全参微调显存爆炸的根源:8B 参数 × 8 bytes = 64 GB,仅优化器状态就干掉一张 A100-80G。
省显存的优化器:
| 优化器 | 每参数 | 代价 |
|---|---|---|
| AdamW(FP32 状态) | 8 bytes | 基线 |
AdamW 8-bit(bitsandbytes) | 2 bytes | 轻微收敛速度下降 |
| AdamW Paged(QLoRA 标配) | 2 bytes + CPU | 高显存利用率但稍慢 |
| Adafactor | 4 bytes | 对大模型有时不如 AdamW |
激活值(和 batch × seq 有关)
激活值(activations)是反向传播需要保存的中间张量。它的大小正比于 batch size × sequence length × hidden dim × 层数。
粗略估算公式(Transformer):
其中 是一个常数(10-20,和具体实现有关)。举例:7B 模型,batch=4,seq=2048,BF16——激活约 8-12 GB。
省激活的神器是 gradient checkpointing:
- 做法:只保存部分层的激活,反向传播时重新前向计算丢弃的层
- 代价:算力增加约 20-35%
- 收益:激活显存降低 约 50-70%
model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False}
)KV Cache(推理时的大户)
训练期用不到 KV cache(因为每步都要重算所有位置的 attention)。推理期它会占很大的显存:
其中 2 来自 K 和 V 两份。对 7B 模型、batch=1、seq=4096,KV cache 约 2 GB。做 SFT 时要在训练结束后关闭 use_cache 以节约显存:
model.config.use_cache = False # 训练时
# 训练完成、准备推理时
model.config.use_cache = True一个可用的粗略估算
整合以上,7B 模型 + LoRA + batch=4 + seq=2048 的显存估算:
| 组件 | 计算 | 显存 |
|---|---|---|
| 模型参数(BF16) | 7B × 2 | 14 GB |
| 可训参数(LoRA, r=32) | ~40M × 2 | 0.08 GB |
| 梯度(LoRA) | ~40M × 2 | 0.08 GB |
| 优化器状态(AdamW FP32) | ~40M × 8 | 0.32 GB |
| 激活(gradient_ckpt=True) | ~ | 5-8 GB |
| CUDA 内核 + 杂项 | ~ | 2 GB |
| 合计 | ~22-25 GB |
一张 A100-40G 轻松跑。如果换 QLoRA(4-bit),模型权重降到 4-5 GB,总显存掉到 ~15 GB,RTX 4090 / 3090 都能跑。
数据量与 epoch 建议
SFT 不是预训练,过大的数据量和 epoch 反而会损害效果(模型会被训"死"——失去原有的多样性和推理能力)。
| 数据量 | 建议 epoch | 备注 |
|---|---|---|
| < 1K 条 | 3-5 | 典型场景:风格适配、格式约束;LIMA 实验 |
| 1K-10K 条 | 2-3 | 最常见场景:领域助手、客服机器人 |
| 10K-100K 条 | 1-2 | 指令跟随类通用数据 |
| > 100K 条 | 1 | 接近预训练,做多 epoch 易过拟合 |
停止训练的信号:
- eval_loss 开始上升,而 train_loss 继续下降 → 过拟合
- train_loss 持续下降但生成质量反而变差(用 held-out 评估集验证)
- 输出的多样性显著降低(模型开始重复相同表达)
Batch size 与梯度累积
物理 batch size(一次前向传播的样本数)受显存限制,有效 batch size(梯度更新使用的样本数)决定训练稳定性。通过梯度累积(gradient accumulation)解耦两者:
经验有效 batch size:
| 场景 | 有效 batch size |
|---|---|
| LoRA 小规模 SFT(1-10K 样本) | 16-32 |
| LoRA 中规模 SFT(10-100K) | 32-64 |
| 全参微调 | 64-128 |
| 预训练 | 512-1024+ |
# 单卡 A100-40G 上 Qwen2.5-7B 的 LoRA 典型配置
per_device_train_batch_size=4, # 物理 batch
gradient_accumulation_steps=8, # 累积 8 步
# → 有效 batch size = 4 × 8 = 32学习率 × batch × 数据量
一个经验规律(平方根规则):
即 batch 增大 4 倍,lr 大约增大 2 倍(而不是 4 倍)。这对多卡训练尤其重要——从 1 卡加到 4 卡时别忘了调整 lr。
LoRA 的 lr 基准:
- 小数据(<10K):
1e-4到2e-4 - 中数据(10-100K):
2e-4到5e-4 - warmup 比例:10%(第一个 epoch 的前 10% 步数线性升温)
训练时长预估
以 Qwen2.5-7B + QLoRA + A100-40G 为例的粗略估算:
| 数据量 | seq_len | batch | 时长 |
|---|---|---|---|
| 1K 条 | 2048 | 4 | ~15 分钟/epoch |
| 10K 条 | 2048 | 4 | ~2 小时/epoch |
| 100K 条 | 2048 | 4 | ~20 小时/epoch |
实测速度(samples/sec)受以下因素影响很大:
- packing 是否开启(开了 1.5-3×)
- gradient_checkpointing(关了快 20-30%,但显存吃紧)
- unsloth(约 2× 加速)
- Flash Attention 2(开了快 10-30%)
何时需要多卡
LoRA/QLoRA 场景下,大部分 7B-14B 的微调可以在单卡完成。真正需要多卡的场景:
| 场景 | 推荐方案 |
|---|---|
| 数据量 > 50K 且要 > 1 epoch | 多卡 DDP(数据并行) |
| 全参数微调 7B+ | DeepSpeed ZeRO-2 或 ZeRO-3 |
| 全参数微调 32B+ | DeepSpeed ZeRO-3 + CPU offload |
| 超大模型 70B+ 的 LoRA | FSDP(PyTorch 原生) |
DDP / FSDP / DeepSpeed 简介
Distributed Data Parallel:每张卡完整复制模型和优化器状态,只对梯度做 all-reduce 同步。
- 优点:简单,速度快
- 缺点:每张卡都要放得下完整模型
- 适用:LoRA / QLoRA 的多卡加速(省的是时间,不是显存)
ZeRO(Zero Redundancy Optimizer):把优化器状态、梯度、参数分片到多张卡上。
| 阶段 | 分片对象 | 显存节省 |
|---|---|---|
| ZeRO-1 | 优化器状态 | ~4× |
| ZeRO-2 | 优化器 + 梯度 | ~8× |
| ZeRO-3 | 优化器 + 梯度 + 参数 | ~N×(N = GPU 数) |
- 适用:全参数微调 7B+
- 代价:通信开销增大,小 batch 时速度下降明显
Fully Sharded Data Parallel:PyTorch 2.0+ 原生分布式方案,和 ZeRO-3 思路相似。
- 优点:原生 PyTorch 集成
- 缺点:生态不如 DeepSpeed 成熟
- 适用:大模型 LoRA(70B+)、全参数微调(32B+)
深入阅读:分布式训练的数学细节(gradient synchronization、parameter sharding、activation checkpointing 的 tradeoff)在姊妹课程 大语言模型后训练实践 中有单独一讲。本讲到此为止,够你用单卡 / 双卡跑完绝大多数 SFT 场景。
一份"我该怎么配"速查表
model: Qwen/Qwen2.5-7B
method: LoRA (BF16, r=32)
per_device_batch: 4
grad_accum: 8
seq_len: 2048
packing: true
gradient_checkpointing: true
# → 有效 batch 32, 显存 ~32 GBmodel: Qwen/Qwen2.5-7B
method: QLoRA (NF4, r=32)
per_device_batch: 2
grad_accum: 16
seq_len: 2048
packing: true
gradient_checkpointing: true
# → 有效 batch 32, 显存 ~18 GBmodel: Qwen/Qwen2.5-1.5B
method: QLoRA (NF4, r=16) + unsloth
per_device_batch: 2
grad_accum: 8
seq_len: 1024
packing: true
gradient_checkpointing: true
# → 有效 batch 16, 显存 ~12 GBmodel: Qwen/Qwen2.5-32B
method: LoRA (BF16, r=64)
distributed: DeepSpeed ZeRO-2
per_device_batch: 2
grad_accum: 4
num_gpus: 4
seq_len: 2048
# → 有效 batch 32, 总显存 ~4 × 40 GB本节小结
| 概念 | 要点 |
|---|---|
| 显存四部分 | 参数 + 梯度 + 优化器 + 激活 |
| 优化器是大头 | AdamW FP32 = 8 bytes/参数,是参数本身的 4 倍 |
| LoRA 省的是后三部分 | 梯度、优化器、激活都只作用于 ~1% 的可训参数 |
| Gradient checkpointing | 省 50-70% 激活,代价是慢 20-35% |
| 有效 batch size | per_device × grad_accum × num_gpus |
| SFT 数据量 | 1K-10K 条最常见,2-3 epoch 够了 |
| lr 经验值 | LoRA 2e-4,全参 2e-5,batch 翻倍 lr 倍 |
| 多卡入口 | DDP(LoRA 加速)/ ZeRO(省显存)/ FSDP(超大模型) |