Skip to content

LLM 训练与对齐

上一篇 LLM 介绍 讲了"LLM 是什么、能做什么"。这一篇把"怎么把它练出来、怎么让它的行为贴合人类意图"讲清楚。范围偏工程实现,不深入推导数学,也不涉及具体框架的逐行教程。

1. 预训练(Pre-training)

预训练是把"知识"装进模型参数的阶段。目标极简:给定一段上文,预测下一个 token。但要做到这件事,工程上需要回答四个问题:数据怎么来、目标怎么定、模型怎么并行、训到什么程度停。

1.1 数据:从网页到训练样本

典型流水线:

  1. 来源 -- CommonCrawl(网页)、Wikipedia、GitHub、ArXiv、Books、StackExchange 等。闭源前沿模型还会用私有授权语料(版权协议文本、出版社、代码仓库等)。
  2. 清洗 -- 去重(MinHash / SimHash 文档级去重)、去 HTML/模板垃圾、语言识别(用 fastText 等模型过滤掉非目标语言)、质量过滤(基于 perplexity 阈值、规则黑名单、合成数据 quality classifier)。
  3. 敏感信息 -- PII 检测与替换、毒化内容(有毒/偏见/违法)过滤、儿童性虐待材料(CSAM)强制过滤。
  4. 分词(Tokenization) -- BPE / WordPiece / Unigram。LLama 系用 BPE,词表大小常见 32k-128k;中文场景需要确认分词器对中文的覆盖率(字节回退会显著增加 token 数,推高成本)。
  5. 打包与混采 -- 把不等长样本 pack 到固定长度(2k/4k/8k/32k tokens),按"数据混合比例"(语料配比)采样,数学/代码/多语言等需要按目标能力加权。

经验法则:数据质量 > 数据量。Meta 的 LLaMA-2 paper 专门强调"清洗后 1.4T tokens 比 7T 原始语料更有效"。

1.2 训练目标:Next-Token Prediction

主流 LLM 都用同一个目标函数:

$$ L(\theta) = - \sum_t \log p_\theta(x_t | x_{<t}) $$

即"在 t 之前所有 token 的条件下,预测 t 时刻 token 的对数似然的负数"。

实现细节:

  • 损失掩码 -- 只在文本 token 上算 loss,prompt 区域通常 mask 掉(对话格式下 "user:" 的部分不参与损失)。
  • 序列打包(Packing) -- 多条短样本拼接到 4096/8192 长度的 context 里,loss 只在真实 token 上算,不计算 padding。
  • 位置编码 -- 主流是 RoPE(Rotary Position Embedding)及其变体(YARN/ALiBi),支持长上下文扩展。

1.3 分布式训练:数据并行 / 张量并行 / 流水并行

模型参数量动辄十亿到千亿,单卡放不下,必须切:

策略切什么典型框架适合规模
数据并行(DP / DDP)切 batch 维度,每卡存完整模型PyTorch DDP,DeepSpeed几十亿参数内
ZeRO(优化器状态切分)切优化器状态 / 梯度 / 参数DeepSpeed ZeRO-1/2/3,FSDP几十到几百亿
张量并行(TP)切矩阵乘法的内维Megatron-LM,DeepSpeed几十亿起,大模型必开
流水并行(PP)切模型层(把不同层放不同卡)Megatron,PipelineParallel几十亿起,显存实在放不下时
3D / 4D 并行上述几者混合Megatron-DeepSpeed70B+ 训练标配
序列并行(SP)切 attention/MLP 的 sequence 维Megatron, Ring Attention长上下文训练(>32k)
专家并行(EP)切 MoE 的 expert 维度Megablocks, DeepSpeed-MoE训练 MoE 模型( Mixtral / DeepSeek-MoE )

实践提示:7B 模型单机 8 卡 A100 通常用 ZeRO-2/3 + TP=1 即可;70B+ 必须 3D 并行 + 节点间 NVLink/IB 高带宽网络。没有 RDMA/IB 互联时,7B 训练都会卡在 all-reduce 上。

1.4 训练中要观察什么

  • loss 曲线 -- 整体下降,最后 1% 应当平缓。过拟合表现为 train loss 持续下降但 eval loss 拐头。
  • 梯度范数(grad norm) -- 突刺往往预示梯度爆炸或数据问题。
  • 学习率调度 -- 主流 warmup + cosine decay;Mistral 7B paper 报告过相对小的学习率(2e-5~1e-5)对长训练更稳。
  • 有效 batch size -- 真实梯度 = 单卡 batch × 累积步数 × DP 数。学习率一般与 batch size 平方根成正比缩放。
  • GPU 指标 -- SM 利用率、显存、HBM 带宽、互联带宽;MFU(Model FLOPs Utilization) 50%+ 是健康线
  • checkpoint 频率 -- 每 N 步 + best-eval 保留;TB 级 checkpoint 必须用并行/分片 IO(NVMe-oF / 并行文件系统)。

2. 监督微调(Supervised Fine-Tuning, SFT)

预训练完的模型是个"文本续写机器",不会按指令回答。SFT 的目标:让模型学会"看到指令 → 输出期望回答"

2.1 数据形态

最小一行 JSONL:

json
{"messages":[
  {"role":"user","content":"..."},
  {"role":"assistant","content":"..."}
]}
  • 单轮 -- Q → A
  • 多轮 -- [U1, A1, U2, A2, ...]
  • 工具调用 -- tool_calls / tool 角色
  • 思考链(CoT) -- 在 assistant 内部嵌入 <think>...</think>

2.2 数据配比与质量

  • 数量级:几千到几十万条都能出效果,质量 >> 数量
  • 多能力混合:通用对话 / 指令跟随 / 写作 / 代码 / 推理 / 工具调用
  • 格式:与目标推理时的格式一致(同种 chat template:ChatML、Llama-3、Qwen 等)
  • 评估:自建 eval 集(覆盖目标能力)+ 公开 benchmark(MMLU / GSM8K / HumanEval / MT-Bench)

不要只盯 loss。SFT 后的 eval 表现往往在 loss 接近饱和时还在快速变化,数据质量提升会带来"loss 没怎么变但能力涨一截"的现象。

3. 对齐(Alignment):让回答"用得顺"

SFT 让模型"会回答",但回答可能冗长、谄媚、不安全、有偏见。对齐把"按指令回答"升级为"按人类偏好回答"。

3.1 RLHF 三步(经典路线)

┌────────┐    ┌────────┐    ┌────────┐
│ SFT 训练│ -> │ RM 训练 │ -> │ PPO 训练 │
│ (上一步)│    │ (奖励模型)│    │ (策略优化)│
└────────┘    └────────┘    └────────┘
  1. 奖励模型 RM -- 给定 prompt + 两个回答 A、B,人类标注哪个更好,训一个分类器输出"回答的标量奖励"。
  2. PPO 优化 -- 用 RM 的分数作奖励,Proximal Policy Optimization 微调 SFT 模型。
  3. KL 惩罚 -- 防止策略偏离 SFT 模型太远(catastrophic forgetting)。

RLHF 训练栈通常由 4 个模型共同占用显存:policy / reference / reward / value(70B 训练时需要 4×H100 节点以上)。

3.2 直接偏好优化(DPO)与变体

DPO 把"先训 RM、再 PPO"压缩成"直接用偏好数据训策略":

  • DPO(2023) -- 把奖励和策略的关系反解,直接在偏好对 (preferred, rejected) 上做 cross-entropy。
  • IPO -- 修正 DPO 在偏好噪声下的过拟合。
  • KTO(Kahneman-Tversky) -- 用"好/坏"二元标签代替成对偏好,采集成本更低。
  • ORPO -- 把 SFT 损失和偏好损失合并,免去 reference 模型。

DPO 路线在 7B-70B 规模上几乎"用一套偏好数据、跑得更快、效果持平或略好"。代价是对偏好数据质量更敏感(噪声会直接折损最终表现)。

3.3 选哪条路

路线训练成本数据成本效果上限适合场景
SFT only内部知识库 / 任务专属
SFT + DPO中高大多数"想要好一点"的场景
SFT + RLHF(PPO)公开产品 / 高安全要求
SFT + RLAIF(AI 反馈)中高中高RM 难以人工标注时(用强模型作 judge)
Online DPO / Iterative DPO中高想要追上前沿但没算力做 PPO

4. 参数高效微调(PEFT)

全参微调几十亿参数成本太高,PEFT 只调一小撮参数,常在 1 张消费级显卡上即可。

方法思路显存省效果接近全参?
Adapter在 Transformer 层插入瓶颈 MLP30-50%
Prefix Tuning在 attention 前缀上训 soft prompt50%+
LoRA在 attention 的 Q/V 矩阵上引入低秩分解 A·B60-80%高(对常见任务)
QLoRALoRA + 4-bit 量化加载基座80%+接近 LoRA
DoRA拆分权重为"方向 + 幅度",LoRA 增强版类似 LoRA通常略好
IA3乘性低秩向量极高

4.1 LoRA 关键参数

  • r(rank) -- 低秩矩阵的秩,常见 8/16/32/64。越大越接近全参,越容易过拟合。
  • alpha -- 缩放系数,通常 alpha = 2 * r 是常见起点。
  • target_modules -- 哪些层加 LoRA,Q/V 是经典起点,Q/K/V/O/FFN 全开能涨一些效果但参数量翻倍。
  • dropout -- 防过拟合,0.05-0.1 常见。

4.2 训练资源粗估(LoRA,7B 模型,4090 24G)

  • 量化基座(4-bit,QLoRA) -- 显存 ~10GB
  • batch=1, seq=2048 + grad-accum=16 -- 训练每步 ~2-4 秒
  • 5k 条样本 -- 大约 1-2 个 epoch 训完,合计 5-15 小时

LoRA 的副作用:合并权重后推理路径不再带 LoRA 矩阵,显存与全参一致;但LoRA 适配器作为基座模型的"插件" 可以被多任务多版本并存,在多领域模型分发场景下非常实用。

5. 评估:什么算"训练好了"

loss 下降只是必要条件,充分条件是模型在目标任务上能用。评估要分三层:

5.1 自动化 benchmark(快速、便宜、可复现)

  • 通用知识 -- MMLU、MMLU-Pro、C-Eval
  • 推理 -- GSM8K、MATH、ARC
  • 代码 -- HumanEval、MBPP、LiveCodeBench
  • 长上下文 -- Needle-in-a-Haystack、LongBench
  • 中文 -- C-Eval、CMMLU

5.2 内部 eval 集(更接近真实业务)

  • 自建 ~100-500 条 prompt,覆盖目标场景(角色、语气、拒答、格式、长尾问题)
  • 用强模型(GPT-4 / Claude / DeepSeek-V3)做 LLM-as-judge,打分或两两比较
  • 关键:评分 prompt 要和目标能力对齐,否则评分噪声巨大

5.3 人工评估(贵、慢,但最终裁判)

  • 5-10 个真实用户盲测
  • 关键维度:有用性 / 准确性 / 简洁性 / 安全性 / 风格
  • 与自动 eval 双轨运行,发现自动化评分"系统性盲点"时回退到人评

最后一步别省。一个 7B 模型 fine-tune 完上生产前,人工跑 50 个真实场景比 5 个 benchmark 都管用。

6. 常见坑

  • Tokenizer 错配 -- 用了与基座不同的分词器,loss 看着正常但能力全失。
  • Chat template 错配 -- 训练时用 ChatML,推理时用 Llama-3,首尾 token 不对齐,模型在"等待 user 发言"上卡住。
  • 学习率过大 -- SFT LR 6e-5 在很多 7B 模型上会让回答变得胡言乱语,先 1e-5 起步。
  • 数据污染 -- eval 集意外混进训练集,benchmark 虚高;尤其是 HumanEval 这种小数据集。
  • Reward hacking -- RLHF 后模型学会"写长且看起来很自信"以骗过 RM,但内容空洞。需要 KL 约束和定期人工抽查。
  • 混合精度 / 数值不稳定 -- bf16 训练时 loss 出现 NaN,通常是 attention softmax 溢出或 grad 累积缩放不对。
  • 遗忘 -- 微调后通用能力下降(MMLU 跌、HumanEval 跌)。可以加 replay buffer(混入 10-20% 通用数据)、限制 LR、限制 epoch。

7. 进一步阅读

  • Megatron-LM 论文与仓库(NVIDIA 大模型训练基线)
  • DeepSpeed + ZeRO 三部曲(Microsoft)
  • Hugging Face TRL 库(SFT/DPO/PPO 一站式实现)
  • LoRA 原论文 + QLoRA 论文(参数高效微调经典)
  • Anthropic 的 Constitutional AI 论文(RLAIF 一条重要分支)
  • LLaMA / Qwen / DeepSeek 各代技术报告(实践范本)

(下一篇将围绕推理优化:量化、KV-cache、speculative decoding、continuous batching。)