极市导读
如何通过人类反馈来优化和引导强化学习模型,使其更好地符合人类的价值观和期望? >>加入极市CV技术交流群,走在计算机视觉的最前沿
01 偏好学习与对齐
去年十月份左右开始写了一篇文章 Iterative Preference Learning from Human Feedback: Bridging Theory and Practice for RLHF under KL-Constraint,文章主要包括两个部分,一个是想要做整个 RLHF 的理论,另一个是想展现在 RLHF 中加入 online data 的好处,从而启发大家在这个方向进行探索(而不是不停的设计 offline 算法的变种)。
论文链接:
https://arxiv.org/pdf/2312.11456
最近正好刚刚写完一个关于奖励函数训练的文章,于是决定一起写一篇文章来介绍一下最近的这个工作。
文章中会有一些数学,我尽量把数学细节省略,而集中在解释理论带来的启示与实验结果验证上,如果对理论感兴趣的同学可以参看我们原始的论文,或者也可以私信我和我一起讨论。
我们在这里讨论最经典的 Bradley-Terry 模型下奖励函数优化的框架,最近有一些工作考虑 general preference oracle,例如 IPO,与 nash learning,这篇文章中的技术大部分也可以扩展过去,之后有机会再写一些文章来讨论。
1.1 偏好学习的定义
我们用 表示一个 prompt, 用 表示一个回复。我们假定我们有一个初始模型 , 它经过预训练和监督微调(instruction-following training), 我们的目标是更改它的输出分布, 使得它能被人类所喜欢。
偏好信号: 与 SFT 不同, 在 SFT 中我们从标记数据中学习, 而 RLHF 从相对反馈中学习。形式上, 对于每个比较对 , 偏好预言满足 , 表示在给定提示 的情况下, 优于 的概率。
偏好数据收集: 我们假设提示是从分布 中采样的, 并且我们有两个行为策略 用来采集回复, 那么一个比较对的采集方式可以表示为
其中随机偏好信号 表示 更好, 反之, 表示 更为人喜欢。我们总结大模型训练的不同阶段:
Bradley-Terry model and Reward: 实践中最广泛使用的偏好模型是 Bradley Terry 模型:我们假设存在一个真实奖励函数 ,使得偏好概率满足:
由此,优化 preference oracle 被转化到奖励函数的优化上。现实中,由于 BT model 不能完全刻画人类偏好,也出于防止过拟合有限偏好数据,训练稳定性等等考虑,我们一般优化下面的带正则的目标:
1.2 Offline/Online; Off-policy/On-policy
我们用 Offline 指代从一个给定的偏好数据集中学习,并且在学习中,我们无法进一步让 Human 给出偏好信号,而相应的 Online 指的是我们可以在训练过程中让 Human 为我们标数据。
换言之,区分 online/offline 最关键的在于 preference signal 采集的模式不同。 因此,下面的这些算法都是 offline 的:
-
DPO 直接在给定的数据集上进行偏好学习:offline -
我们从一个给定的数据集训练得到一个奖励函数,并使用 PPO 优化这个奖励函数:offline -
我们从一个给定的数据集训练得到一个奖励函数,并使用 rejection sampling finetuning 优化:offline
一个相关的概念是 on-policy 与 off-policy,我们用 on-policy 指代那些采集数据的策略与要改进的策略是同一个策略的情况,而 off-policy 指代那些使用某个行为策略采集数据,但是用以改进另一个策略的算法。
换言之,区分 on-policy/off-policy 最关键的在于 responses 采集的模型不同。 我们给出以下例子:
-
DPO 是 off-policy 的 -
我们从一个给定的数据集训练得到一个奖励函数,并使用 PPO 优化这个奖励函数:on-policy
02 主要理论结论
这一节我们主要讨论理论结果,我们作如下的假设。
Computational Oracle: 对于任何给定的奖励函数 ,KL-regularized 优化问题有如下闭式解:
其中 。这个闭式解没法直接利用, 因为归一化系数 需要遍历 一个至少是指数大小的空间。这里我们先假设我们可以近似 。
Linear Reward Space: 我们假设 , 其中 是一个特征抽取函数, 这里只是为了叙述简单, 分析可以用这篇论文里的技术直接推到一般函数情况。
2.1 Offline Learning
在 offline 的情况下,我们从偏好数据集中学习,并且无法进一步让 human 打标签。这种情况下,如果我们采用一个比较保守的奖励函数,在每个点上对奖励增加一个不确定性的惩罚,也就是
其中, 是数据集上的协方差估计。直观理解是, 给定一个固定的数据集, 我们无法准确估计每个点的奖励函数, 因此, 对于不确定性大的估计, 我们要针对这一点进行惩罚, 这个惩罚就是这个点在协方差上的投影 , 之所以要减掉一项针对 期望的特征, 是因为在偏好学习中, 我们只能估计特征差的不确定性而不能处理单个回复的不确定性。我们有如下结论。
Theorem 1: 以高概率成立有,
为了进一步解释这个结果,我们一般会进一步作如下假设:
这里的 取决于采集 offline 数据集的策略对 与目标策略对 的分布偏移。很遗憾的是, 在 LLM 场景下, 由于输出长度很长, 这个偏移量通常非常大, 例如, 在 Claude paper 中, 我们可以看到训练过程中策略与起始点的 KL 能到接近 30。
这说明 . 一般来说, 我们很难预期 offline dataset 与我们目标的分布偏移是小的, 因此 通常是一个极大的数。
2.2 批量混合训练(Batch hybrid learning)
在 LLaMA2,Instruct-GPT,Claude 的 technical report 中(怀念大家还愿意分享技术细节的时代),事实上他们的 RLHF 都不是 offline 进行的,总结而言,他们都是进行一种批量混合训练
-
离线阶段:open-source dataset + 使用 及其变种(例如 best-of-n)采集初始数据集 -
在线阶段:在训练的过程中,将一些中间步骤得到的模型进行部署,让 human 对新模型的输出进行标签,这个步骤通常以周为单位迭代,更新相对比较稀疏,所以是 batch。
我们先提供一个直觉上,为什么在线探索对训练有帮助的解释。首先,由于奖励函数也是从预训练大模型训练而来,通常来说,他的泛化性能并不会特别好,因此,如果我们希望奖励函数能够准确的评估某个分布下的回复,那它必然在训练集中需要看见过这类样本,由于初始的模型 还没有被 align 过,它能采样得到的回复的奖励分布主要集中在比较低的区域,大概如下图所示:
因此, 如果我们用 采集数据并训练奖励函数,因为数据集中高奖励的数据比较少,它很可能不太能准确判断两个高奖励的回复谁更好,这样就会出现所谓的 over-optimization 的现象。
我们可以看到, 实线(ground truth ) 与虚线(训练得到的 )在前期分布偏移较小的时候, 增长趋势是一致的, 这说明训练得到的奖励函数能够比较好的区分在 附近的样本, 随着训练过程中, 策略奖励不断提升, 给出的回复与原始策略的偏移也逐渐增大, 最终我们从 offline 数据集训练出来的 彻底崩溃, 导致真实奖励反而下降。
因此,如果我们在训练的中间,用当前的策略收集一些数据,加入到奖励函数训练集中重新训练奖励函数,就能让奖励函数有能力在高奖励区域给出有效的信号,从而获得更好的模型。 这是一个朴素的在线学习带来好处的直觉。
为了分析这样的一个过程, 我们考虑下面的一个算法, iterative DPO:
这里的 DPO 是对于我们假设的 computational oracle 的一个自然近似,我们有如下结论。
Theorem 2最多经历 d 次迭代,我们可以找到一个策略,其满足
这里的 m 是每轮的 batch size。
首先,和纯 offline 算法相比, 在第二项中,现在计算协方差的数据集有两部分
-
Offline dataset
-
Online dataset
我们预期在线数据会很大程度改进 , 这是由于我们试图去 cover 的对象是 , 而我们采集数据的是 。随着训练进行, , 我们 cover 的目标和采集数据的策略之间的分布偏移逐渐减小, 带来的就是一个更小的 。这也和我们刚才的直觉理解是一致的, 我们在算法训练中间得到的模型, 在采集一些高奖励回复( 密度比较大的地方)有优势, 从而能提供更好的数据进行新的奖励函数训练。
另一方面, hybrid learning 的结果里多了一项和奖励函数空间复杂度有关的项, 如果我们相信奖励函数空间有一些低秩结构(discrimination is easier than generation), 特别是考虑到在 LLM 中分布偏移动辁可以到 , 这一项应当会小于 offline learning 的主项。
两个因素结合起来,引入在线学习的 hybrid iterative dpo 效果应当比纯 offline dpo 应当更好。我们在 HH-RLHF 数据集上进行了实验验证,实验中,我们控制总的访问 ground truth oracle的 次数一定来进行公平的比较,结果如下:
可以看到,随着在线探索进行,hybrid-iterative-dpo 获得的模型每轮都有明显的提升。特别的,红色线 (offline dpo) 与 蓝色线 (用更少数据的 offline dpo)的比较可以看出,到达一定界限后,增加 offline 数据带来的增益并不明显。相反,橙色线与绿色线来看,在线探索得到的数据能大幅度提升模型能力。
我们在准备一个 iterative rlhf 的 github repo,实验相关的 reward model 在 huggingface 上可以找到,训练脚本在 GitHub repo,基于 TRL 实现。
2.3 在线探索策略设计
在线学习一个非常重要的点在于用什么策略去进行 online data 的采集,对强化学习文献熟悉的朋友能够知道这是为了更加有效的探索整个状态-动作空间。
请注意, 上面呈现的 选择, 需要一些条件:
-
最开始的 offline data 对我们的目标至少有一定的覆盖, 使得在初始阶段我们的策略是真的在朝向 改进; -
基于第一个条件, 改进的 才能带来更好的覆盖, 从而形成正向循环。
一个自然的问题是, 如果我们没有一个足够好的 offline dataset 来提供这个初始的覆盖条件怎么办? 这时候我们必须对探索策略进行一些设计, 让模型能够自己去收集足够它收敛到 的数据。具体地说, 我们不需要修改算法, 只需要修改策略选取方式。
我们直观解释一下这个选择:
-
首先,第一个策略仍然是基于历史所有数据跑个 dpo 或者 rlhf 的数据,某种意义上是基于数据我们能做的 best-guess; -
关键在于第二个策略的选择,它需要去最大化它与策略 1 的 feature difference 对应的不确定性。换句话说,如果基于历史,我对这个方向仍然数据很少,没有太多信息,我就应当往这个方向多采一些数据来鼓励探索。
很遗憾的是,在线性情形之外,不确定性的估计没有具体的形式,如何在一般的神经网络里做不确定性的估计仍然是一个 open problem。但是我们至少可以原则上对这个理论结果进行分析:
-
我们的探索策略应该是 ,the best policy we can get given the history,的变种,这体现的是对历史数据信息的利用; -
同时,我们的两个策略应该在保证第一点的情况下,尽可能区分开来,使得他们的 difference 比较大,这样能带来比较好的 diversity 与探索。
事实上,LLaMA2 与 Claude 就是这样做的,而instruct-gpt 没有太多细节来判断他是如何收集在线数据,一些常见的策略是
-
使用训练 中不同 training step 的模型变种; -
使用不同的采样温度参数。
一些同期工作会使用 rejection sampling 来进行探索,具体的说,使用 生成 n 个回复,并选择 best-of-n 回复打随机选取的一个回复,或者 worst-of-n 回复。
03 讨论
这里稍微讨论一下,上面我们讨论的内容,和 PPO,DPO,SLIC,IPO 这些算法设计有什么关系。可以注意到的是,我们在进行理论讨论的时候,假定了我们可以很好的近似针对一个固定奖励函数的 KL-regularized 优化问题,而 PPO,DPO,SLIC,IPO 正是这个子问题的现实近似。
事实上,虽然我们主要讨论的是 iterative DPO,我们可以把上述讨论扩展到任何一个这些算法。事实上,谷歌同期的工作也考虑了类似的在线探索框架,而他们的结论是这些算法差异并不大。
但是每一个算法加入在线探索后,都比他们对应的 offline version 要强得多。我和这篇文章的一作这周会写一个比较完整的对整个 RLHF 目前发展现状的 blog,希望吸引更多机智的小伙伴关注到这个令人兴奋的领域。
当时作这篇文章的初衷,一个是为了作 RLHF 的数学理论,在此之前的工作基本是从 dueling bandit 的框架出发,只关注于奖励的最大化,导致理论最优策略是一个确定性的贪婪策略,同时设置与现实实践也比较远。
另一个原因是,我过去两年作 offline 的 RL 理论比较多,同时有做过一些实验验证,对于分布偏移(distribution shift)比较敏感。在之前作 RAFT 算法的时候,我就注意到在大模型里,分布偏移的情况很严重,所以纯 offline 算法必然有其局限性。而当时很多成功的 project,实际上是在 distill GPT4,所以看起来效果很好,但是这种效果是比较难扩展到更大模型的。
如今开源社区有了 mistral 这样比较强的基座模型,即使不 distill 也能训练出很强的模型,此时恰是引入 online 探索的好时机,也希望后续可以看见更多讨论如何改进探索效率的工作。

公众号后台回复“数据集”获取100+深度学习各方向资源整理
极市干货

点击阅读原文进入CV社区
收获更多技术干货

