TRL(RLHF/DPO 训练)

TRL (Transformer 强化学习)是 HuggingFace 官方用于使用强化学习技术训练语言模型的库。拥有超过 10K 的 GitHub 收藏,它提供了用于大模型对齐的最先进实现,如 RLHF、DPO、PPO、GRPO 等对齐算法。

circle-check

什么是 TRL?

TRL 是许多当今最佳对齐语言模型背后的库。它提供:

  • SFT(有监督微调) — 使用 ChatML 格式的标准指令微调

  • RLHF/PPO — 使用奖励模型的经典近端策略优化(PPO)

  • DPO — 直接偏好优化(无需奖励模型!)

  • GRPO — 群体相对策略优化(DeepSeek-R1 的方法)

  • KTO — Kahneman-Tversky 优化(适用于非配对偏好)

  • 奖励建模 — 从人工偏好数据训练奖励模型

  • IterativeSFT — 更简单设置的在线强化学习

  • ORPO — 赔率比偏好优化

TRL 原生集成了 HuggingFace 生态系统: transformers, peft, datasets, accelerate,以及 bitsandbytes.

circle-info

你应该使用哪种算法?

  • DPO — 最简单、最稳定。当你有配对偏好数据(被选中/被拒绝)时使用。

  • PPO — 功能最强但更复杂。当你有奖励模型或评分函数时使用。

  • GRPO — 非常适合理解/数学任务。DeepSeek-R1 的训练方法。

  • SFT — 在应用任何强化学习方法之前总是从这里开始。


服务器要求

组件
最低
推荐

GPU

RTX 3090(24 GB)

A100 80 GB / H100

显存(VRAM)

16 GB(SFT/DPO 7B + LoRA)

80 GB(7B 完整微调)

内存(RAM)

32 GB

64 GB+

CPU

8 核

16+ 核

存储

100 GB

300 GB+

操作系统(OS)

Ubuntu 20.04+

Ubuntu 22.04

Python

3.9+

3.11

CUDA

11.8+

12.1+

按任务的显存需求

任务
模型
方法
显存(VRAM)

SFT

Llama 3 8B

QLoRA 4-bit

~8 GB

DPO

Llama 3 8B

LoRA

~20 GB

PPO

Llama 3 8B

完整(Full)

~80 GB(2×A100)

GRPO

Qwen 7B

LoRA

~24 GB

SFT

Llama 3 70B

QLoRA 4-bit

~48 GB

DPO

Llama 3 70B

LoRA

~80 GB


端口(Ports)

端口(Port)
服务(Service)
说明(Notes)

22

SSH

终端访问、文件传输、监控

TRL 是一个训练库——它作为 CLI/Python 脚本运行,不需要 Web 服务器。


在 Clore.ai 上的安装

步骤 1 — 租用服务器

  1. 筛选 显存 ≥ 24 GB (RTX 3090、A100 或 H100)

  2. 选择一个 PyTorchCUDA 12.1 基础镜像(base image)

  3. 选择 存储 ≥ 200 GB 用于模型和数据集

  4. 打开端口 22 用于 SSH 访问

步骤 2 — 通过 SSH 连接

步骤 3 — 安装 TRL

步骤 4 — HuggingFace 验证(认证)

步骤 5 — 可选:Weights & Biases 跟踪


有监督微调(SFT)

SFT 始终是在应用任何强化学习技术之前的第一步。

准备你的数据集

SFT 训练脚本


DPO(直接偏好优化)

DPO 是最受欢迎的对齐方法——无需奖励模型,只需偏好对。

准备 DPO 数据集

DPO 训练脚本


PPO(近端策略优化)

PPO 是经典的 RLHF 方法——当你有奖励信号时使用:


GRPO(群体相对策略优化)

GRPO 在 DeepSeek-R1 的推理训练中被使用:


多 GPU 训练

使用 accelerate 进行分布式训练:


使用 TRL CLI

TRL 提供方便的 CLI 命令:


监控训练


Clore.ai 的 GPU 建议

TRL 训练是对显存消耗极高的工作负载之一。根据模型大小和方法选择你的 GPU:

任务
GPU
说明(Notes)

7–8B 的 SFT / DPO(QLoRA)

RTX 3090 24 GB

QLoRA 4-bit 大约需要 ~8 GB;可轻松运行;在 Clore.ai 上约 $0.12/小时

7–8B 的 SFT / DPO(LoRA bf16)

RTX 4090 24 GB

与 3090 相同的显存但计算速度约快 30%;非常适合加快迭代速度

7B 的完整 SFT 或 13B 的 DPO

A100 40 GB

40 GB 可容纳 7B 的全精度训练;ECC 内存可避免静默错误

PPO / 7B 的完整微调,或任何 70B QLoRA

A100 80 GB

PPO 需要在显存中同时放置策略模型和参考模型的两份;80 GB 可在不 OOM 的情况下运行两者

实用提示: 在 RTX 3090 上使用 QLoRA 开始实验——在 10K 个示例上训练 Llama 3 8B 大约需要 ~2 小时。一旦验证了流水线,就迁移到 A100 80GB 以进行全精度运行或训练 70B 模型。

速度指标(Llama 3 8B SFT,QLoRA,batch=4,seq=2048):

  • RTX 3090:约 1,100 tokens/sec 的训练吞吐量

  • RTX 4090:约 1,450 tokens/sec

  • A100 80GB:约 2,800 tokens/sec(全 bf16,无量化)


故障排除

CUDA 内存不足(Out of Memory)

损失为 NaN

DPO: chosen_rewards > rejected_rewards 为 False

训练非常慢

tokenizer.pad_token 警告

权限被拒绝 / HuggingFace 401


保存并共享你的模型


有用的链接

最后更新于

这有帮助吗?