人工智能实践(语言智能)
第6讲:大模型微调

6.5 规模与显存

显存估算公式、数据量与 epoch 建议、batch size 与 grad accumulation、多卡简介

训练显存的四个来源

训练一个大模型时,显存不是只放模型参数,而是由四个部分构成:

Memtotal=Memparams+Memgrads+Memoptim+Memact\text{Mem}_{\text{total}} = \text{Mem}_{\text{params}} + \text{Mem}_{\text{grads}} + \text{Mem}_{\text{optim}} + \text{Mem}_{\text{act}}

每一部分的大小取决于:参数量、精度、优化器、batch size、sequence length。我们逐个拆开。

参数与梯度

参数(weights)和梯度(gradients)的显存是最直观的一部分

精度每参数占用
FP324 bytes
FP16 / BF162 bytes
INT81 byte
INT4 / NF40.5 byte

梯度的精度一般和计算精度一致(BF16 训练时,梯度也是 BF16)。

Memparams=NparamsBparamMemgrads=NtrainableBgrad\text{Mem}_{\text{params}} = N_{\text{params}} \cdot B_{\text{param}} \qquad \text{Mem}_{\text{grads}} = N_{\text{trainable}} \cdot B_{\text{grad}}

LoRA 的关键好处NtrainableN_{\text{trainable}} 只有全量的 0.5%-5%,梯度显存几乎可以忽略。

优化器状态(最大头)

AdamW 对每个可训练参数维护:

  • 一阶动量 mm(FP32):4 bytes/参数
  • 二阶动量 vv(FP32):4 bytes/参数

即使用的是 BF16 训练,动量仍然保持 FP32(业界标准,否则极不稳定)。

Memoptim-AdamW=Ntrainable8 bytes\text{Mem}_{\text{optim-AdamW}} = N_{\text{trainable}} \cdot 8 \text{ bytes}

这就是全参微调显存爆炸的根源:8B 参数 × 8 bytes = 64 GB,仅优化器状态就干掉一张 A100-80G。

省显存的优化器

优化器每参数代价
AdamW(FP32 状态)8 bytes基线
AdamW 8-bit(bitsandbytes2 bytes轻微收敛速度下降
AdamW Paged(QLoRA 标配)2 bytes + CPU高显存利用率但稍慢
Adafactor4 bytes对大模型有时不如 AdamW

激活值(和 batch × seq 有关)

激活值(activations)是反向传播需要保存的中间张量。它的大小正比于 batch size × sequence length × hidden dim × 层数

粗略估算公式(Transformer):

MemactBLdnlayersBprecisionc\text{Mem}_{\text{act}} \approx B \cdot L \cdot d \cdot n_{\text{layers}} \cdot B_{\text{precision}} \cdot c

其中 cc 是一个常数(10-20,和具体实现有关)。举例:7B 模型,batch=4,seq=2048,BF16——激活约 8-12 GB

省激活的神器是 gradient checkpointing

  • 做法:只保存部分层的激活,反向传播时重新前向计算丢弃的层
  • 代价:算力增加约 20-35%
  • 收益:激活显存降低 约 50-70%
model.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs={"use_reentrant": False}
)

KV Cache(推理时的大户)

训练期用不到 KV cache(因为每步都要重算所有位置的 attention)。推理期它会占很大的显存:

KV cache=2BLnlayersnheadsdheadBprecision\text{KV cache} = 2 \cdot B \cdot L \cdot n_{\text{layers}} \cdot n_{\text{heads}} \cdot d_{\text{head}} \cdot B_{\text{precision}}

其中 2 来自 K 和 V 两份。对 7B 模型、batch=1、seq=4096,KV cache 约 2 GB。做 SFT 时要在训练结束后关闭 use_cache 以节约显存:

model.config.use_cache = False      # 训练时
# 训练完成、准备推理时
model.config.use_cache = True

一个可用的粗略估算

整合以上,7B 模型 + LoRA + batch=4 + seq=2048 的显存估算:

组件计算显存
模型参数(BF16)7B × 214 GB
可训参数(LoRA, r=32)~40M × 20.08 GB
梯度(LoRA)~40M × 20.08 GB
优化器状态(AdamW FP32)~40M × 80.32 GB
激活(gradient_ckpt=True)~5-8 GB
CUDA 内核 + 杂项~2 GB
合计~22-25 GB

一张 A100-40G 轻松跑。如果换 QLoRA(4-bit),模型权重降到 4-5 GB,总显存掉到 ~15 GB,RTX 4090 / 3090 都能跑。

数据量与 epoch 建议

SFT 不是预训练,过大的数据量和 epoch 反而会损害效果(模型会被训""——失去原有的多样性和推理能力)。

数据量建议 epoch备注
< 1K 条3-5典型场景:风格适配、格式约束;LIMA 实验
1K-10K 条2-3最常见场景:领域助手、客服机器人
10K-100K 条1-2指令跟随类通用数据
> 100K 条1接近预训练,做多 epoch 易过拟合

停止训练的信号

  • eval_loss 开始上升,而 train_loss 继续下降 → 过拟合
  • train_loss 持续下降但生成质量反而变差(用 held-out 评估集验证)
  • 输出的多样性显著降低(模型开始重复相同表达)

Batch size 与梯度累积

物理 batch size(一次前向传播的样本数)受显存限制,有效 batch size(梯度更新使用的样本数)决定训练稳定性。通过梯度累积(gradient accumulation)解耦两者:

effective_batch_size=per_device_batch×grad_accum_steps×num_gpus\text{effective\_batch\_size} = \text{per\_device\_batch} \times \text{grad\_accum\_steps} \times \text{num\_gpus}

经验有效 batch size

场景有效 batch size
LoRA 小规模 SFT(1-10K 样本)16-32
LoRA 中规模 SFT(10-100K)32-64
全参微调64-128
预训练512-1024+
# 单卡 A100-40G 上 Qwen2.5-7B 的 LoRA 典型配置
per_device_train_batch_size=4,     # 物理 batch
gradient_accumulation_steps=8,      # 累积 8 步
# → 有效 batch size = 4 × 8 = 32

学习率 × batch × 数据量

一个经验规律(平方根规则):

lrnew=lrbasebatchnewbatchbase\text{lr}_{\text{new}} = \text{lr}_{\text{base}} \cdot \sqrt{\frac{\text{batch}_{\text{new}}}{\text{batch}_{\text{base}}}}

即 batch 增大 4 倍,lr 大约增大 2 倍(而不是 4 倍)。这对多卡训练尤其重要——从 1 卡加到 4 卡时别忘了调整 lr。

LoRA 的 lr 基准

  • 小数据(<10K):1e-42e-4
  • 中数据(10-100K):2e-45e-4
  • warmup 比例:10%(第一个 epoch 的前 10% 步数线性升温)

训练时长预估

Qwen2.5-7B + QLoRA + A100-40G 为例的粗略估算:

数据量seq_lenbatch时长
1K 条20484~15 分钟/epoch
10K 条20484~2 小时/epoch
100K 条20484~20 小时/epoch

实测速度(samples/sec)受以下因素影响很大:

  • packing 是否开启(开了 1.5-3×)
  • gradient_checkpointing(关了快 20-30%,但显存吃紧)
  • unsloth(约 2× 加速)
  • Flash Attention 2(开了快 10-30%)

何时需要多卡

LoRA/QLoRA 场景下,大部分 7B-14B 的微调可以在单卡完成。真正需要多卡的场景:

场景推荐方案
数据量 > 50K 且要 > 1 epoch多卡 DDP(数据并行)
全参数微调 7B+DeepSpeed ZeRO-2 或 ZeRO-3
全参数微调 32B+DeepSpeed ZeRO-3 + CPU offload
超大模型 70B+ 的 LoRAFSDP(PyTorch 原生)

DDP / FSDP / DeepSpeed 简介

Distributed Data Parallel:每张卡完整复制模型和优化器状态,只对梯度做 all-reduce 同步。

  • 优点:简单,速度快
  • 缺点:每张卡都要放得下完整模型
  • 适用:LoRA / QLoRA 的多卡加速(省的是时间,不是显存)

ZeRO(Zero Redundancy Optimizer):把优化器状态、梯度、参数分片到多张卡上。

阶段分片对象显存节省
ZeRO-1优化器状态~4×
ZeRO-2优化器 + 梯度~8×
ZeRO-3优化器 + 梯度 + 参数~N×(N = GPU 数)
  • 适用:全参数微调 7B+
  • 代价:通信开销增大,小 batch 时速度下降明显

Fully Sharded Data Parallel:PyTorch 2.0+ 原生分布式方案,和 ZeRO-3 思路相似。

  • 优点:原生 PyTorch 集成
  • 缺点:生态不如 DeepSpeed 成熟
  • 适用:大模型 LoRA(70B+)、全参数微调(32B+)

深入阅读:分布式训练的数学细节(gradient synchronization、parameter sharding、activation checkpointing 的 tradeoff)在姊妹课程 大语言模型后训练实践 中有单独一讲。本讲到此为止,够你用单卡 / 双卡跑完绝大多数 SFT 场景。

一份"我该怎么配"速查表

model: Qwen/Qwen2.5-7B
method: LoRA (BF16, r=32)
per_device_batch: 4
grad_accum: 8
seq_len: 2048
packing: true
gradient_checkpointing: true
# → 有效 batch 32, 显存 ~32 GB
model: Qwen/Qwen2.5-7B
method: QLoRA (NF4, r=32)
per_device_batch: 2
grad_accum: 16
seq_len: 2048
packing: true
gradient_checkpointing: true
# → 有效 batch 32, 显存 ~18 GB
model: Qwen/Qwen2.5-1.5B
method: QLoRA (NF4, r=16) + unsloth
per_device_batch: 2
grad_accum: 8
seq_len: 1024
packing: true
gradient_checkpointing: true
# → 有效 batch 16, 显存 ~12 GB
model: Qwen/Qwen2.5-32B
method: LoRA (BF16, r=64)
distributed: DeepSpeed ZeRO-2
per_device_batch: 2
grad_accum: 4
num_gpus: 4
seq_len: 2048
# → 有效 batch 32, 总显存 ~4 × 40 GB

本节小结

概念要点
显存四部分参数 + 梯度 + 优化器 + 激活
优化器是大头AdamW FP32 = 8 bytes/参数,是参数本身的 4 倍
LoRA 省的是后三部分梯度、优化器、激活都只作用于 ~1% 的可训参数
Gradient checkpointing省 50-70% 激活,代价是慢 20-35%
有效 batch sizeper_device × grad_accum × num_gpus
SFT 数据量1K-10K 条最常见,2-3 epoch 够了
lr 经验值LoRA 2e-4,全参 2e-5,batch 翻倍 lr 2\sqrt{2}
多卡入口DDP(LoRA 加速)/ ZeRO(省显存)/ FSDP(超大模型)