大数跨境
0
0

RFT(拒绝采样微调):提升大模型推理能力

RFT(拒绝采样微调):提升大模型推理能力 极市平台
2024-06-28
2
↑ 点击蓝字 关注极市平台
作者丨绝密伏击@知乎(已授权)
来源丨https://zhuanlan.zhihu.com/p/703848627
编辑丨极市平台

极市导读

 

论文提出了应用RFT(Rejection sampling Fine-Tuning)拒绝采样来生成和收集正确的推理路径,以此作为增强的微调数据集。 >>加入极市CV技术交流群,走在计算机视觉的最前沿

论文:https://arxiv.org/pdf/2308.01825

github:https://github.com/OFA-Sys/gsm8k-ScRel

简洁版

如果对RFT已经有所了解,那么我们就开门见山,直接介绍论文的主要方法,如下图所示。

图1:使用RFT优化数学推理能力

整体思路是使用多个小模型(比如Llama1/2-7b/13b)生成推理路径,经过质量筛选和多样性筛选之后,用于更大模型Llama2-70B的SFT。

左边是是小模型RFT的流程,右边是更大模型Llama2-70B的SFT流程。具体步骤如下:

1、训练一轮小模型和大模型: 首先使用问题、回答 的数据集(比如GSM8K)训练一轮小模型和大模型;

2、选择推理路径: 然后使用小模型对每个 生成多条推理路径 , 经过质量过滤(保留答案正确的推理路径)和多样性控制(同一个问题尽量选择多样性好的回答), 生成新的数据集 ;

3、使用新的数据集微调模型: 使用新的推理数据集 微调一轮小模型, 并将其用于微调右边的 Ilama2-70B;

4、重复2、3步骤

这里面关键点就是步骤2,选择推理路径。具体做法如下图所示:

图2:选择推理路径

图中的上面部分是算法的整体流程,下面部分是一个具体的示例。

选择推理路径的目的从一组给定的推理路径中选择出具有不同计算过程的路径,以增加模型在训练时的泛化能力。下面结合图2介绍一下大概流程:

  1. 对于每一个问题 , 初始化一个空列表 来存储被选中的推理路径。
  2. 初始化一个集合 来存储已经出现过的方程集合。
  3. 遍历问题 的所有推理路径列表 :
  • 对于每个推理路径 :
  • 使用 get_equation 函数提取出推理路径中的方程列表。
  • 检查方程列表是否不在集合 中(get_equation ):
  • 如果不在, 将推理路径 添加到 列表中, 并将方程列表添加到集合 中。
  • 如果在, 寻找 中已有的推理路径 , 使得 的方程列表与当前的方程列表相同。
  • 如果找到, 计算当前推理路径 与已选推理路径 的距离(Levenshtein distance), 并与之前找到的具有最大距离的推理路径进行比较。如果当前的距离更大, 则用当前推理路径替换之前找到的路径。

下面是一个具体例子(对应图2中的下面部分):

假设问题 目前已经选择的推理路径是:

方程集合是:

当前推理路径 对应的方程是 , 和 具有相同方程的是推理路径

选定推理路径 , 如果:

那么替换 , 这时新的推理路径是:

以上就是论文的主要方法。下面是主要结论:

  • 使用组合模型(Llama1/2-7b/13b)在 k=100 (即每个问题每个模型都生成100条推理路径)时,经过算法“选择路径推理”过滤后,能生成12.84条不同的推理路径;而7B模型只能生成5.25条,33B只能生成2.78条。这也表明了组合模型可以生成多样性更好的推理路径,小模型相比较于大模型能够生成更多的推理路径
  • 使用小模型生成推理路径的方式(图1所示方法),在7B上的准确率从35.9% 提升到49.1%;在33B上的准确率从54.6% 提升到57.9%

备注:之所以在33B上的提升幅度变小,是因为33B基座模型的推理能力更强。如果将GSM8K换成更难的数据集,在33B是上的提升幅度可能会更大。

实验结果

下面是具体的实验结果:

RFT-U13B(Llama1/Llama2-7B/13B)和RFT-U33B的结果相差不大,但是和SFT相比,提升都很明显。

下图是不同模型采样出的推理路径:

33B在k = 100时,只能采样出2.78条不同的推理路径,而U13B(Llama1/Llama2-7B/13B)能采样出12.84条不同的推理路径。

以上是论文思路的概述,若想了解更多细节,请继续阅读下文。

背景

论文提出了应用RFT(Rejection sampling Fine-Tuning)拒绝采样来生成和收集正确的推理路径,以此作为增强的微调数据集。RFT能够生成包含更多独特推理路径的增强样本,更大幅度地提升LLMs的数学推理性能。对于性能较差的LLMs,RFT带来的改进更为显著。此外,论文将来自多个模型的拒绝样本结合起来,使LLaMA-7B在GSM8K上的准确率达到了49.3%,这明显优于有监督微调(SFT)的准确性35.9%。

在介绍论文方法之前,先介绍下什么是RFT,以及RFT和SFT的区别。

什么是RFT?

RFT(Rejection sampling Fine-Tuning)和SFT(Supervised Fine-Tuning)是两种用于微调机器学习模型的方法,特别是在自然语言处理领域。

SFT是一种常见的微调方法,主要步骤如下:

1. 数据收集:收集大量的标注数据,这些数据通常由人类专家根据特定任务进行标注。

2. 模型训练:使用这些标注数据对预训练模型进行微调,使其在特定任务上表现更好。

3. 评估和优化:通过验证集评估模型性能,并根据结果进行优化。

SFT的优点是相对简单直接,只需要高质量的标注数据即可。然而,SFT也有一些局限性,比如对标注数据的质量和数量要求较高。

RFT是一种更为复杂的微调方法,主要步骤如下:

1. 数据生成:首先使用预训练模型生成大量的候选输出。

2. 筛选过程:通过某种筛选机制(如人工评审或自动评分系统)从这些候选输出中挑选出高质量的样本。

3. 模型训练:使用筛选后的高质量样本对模型进行微调。

RFT的关键在于筛选过程,这个过程可以显著提高数据的质量,从而提升模型的性能。筛选机制可以是人工的,也可以是基于某种自动化评分系统的。

区别

1. 数据来源:

- SFT:依赖于预先标注好的高质量数据。

- RFT:通过生成大量候选输出,然后筛选出高质量样本。

2. 数据质量控制:

- SFT:数据质量主要依赖于标注过程的质量控制。

- RFT:通过筛选机制来确保数据质量,即使初始生成的数据质量不高,也可以通过筛选提高。

如何将RFT用于数学推理任务?

RFT的核心思想是利用已有的监督模型来生成新的数据样本,如果将其用于数学推理任务,那么可以通过选择正确的推理路径来增强模型的训练数据集。

1. 生成候选推理路径:使用一个已经通过监督微调(SFT)训练好的模型来生成针对训练集中每个问题的多个候选推理路径。这些路径包括一系列计算步骤,旨在解决问题。

2. 筛选正确路径:从生成的候选路径中筛选出那些能够正确推导出问题答案的推理路径。

3. 去重和多样化:进一步从筛选出的正确路径中选择具有不同计算过程或表达方式的路径,以增加数据集的多样性。这有助于模型学习不同的解决问题的方法。

4. 微调:使用这些经过筛选和去重的推理路径作为新的训练数据,对原始的监督模型进行进一步的微调。

5. 提高泛化能力:通过引入多样化的推理路径,RFT旨在提高模型在未见过的问题上的泛化能力。

将RFT用于数学推理任务,可以利用模型自身生成的数据来增强其推理能力,同时避免了昂贵的人工标注成本。这种方法特别适用于那些难以通过增加监督数据量来提升性能的场景,因为它允许模型从未充分利用的训练数据中学习新的推理策略。

和SFT相比较,RFT具有以下几点优势:

1. 数据增强的有效性:RFT通过拒绝采样的方式,使用监督模型生成并收集正确的推理路径作为额外的微调数据集。这种方法可以在不增加人工标注工作量的情况下,增加数据样本,从而提高模型性能。

2. 推理路径的多样性:RFT特别强调通过增加不同的推理路径来提高LLMs的数学推理能力。这意味着RFT能够提供多种解决问题的方法,有助于模型在面对新问题时有更好的泛化能力。

3. 对性能较差模型的提升效果:论文中提到,RFT对于性能较差的LLMs提升更为明显。这表明RFT可能是一种更为有效的改进手段,特别是对于那些需要显著提高推理能力的模型。

4. 组合多个模型的优势:RFT可以通过组合来自多个模型的拒绝样本来进一步提升性能。这种方法使得LLaMA-7B在GSM8K数据集上的准确率从SFT的35.9%显著提高到49.3%。

5. 计算资源的经济性:尽管RFT在生成样本时可能需要较多的计算资源,但在训练阶段相比从头开始预训练一个LLM来说,它是一种更为经济的方法。这使得RFT成为一种可行的、成本效益更高的改进模型性能的手段。

6. 减少过拟合:RFT通过引入多样化的推理路径,有助于减少模型在训练数据上的过拟合,特别是在大型模型中。

一、引言

为了提升大模型的推理能力,可以使用有监督微调SFT和上下文学习ICL(In-context learning)等方法。

论文分析了SFT和ICL的性能。在一个给定的区间内,预训练损失与SFT和ICL准确性大约呈负线性相关。其次,模型性能与SFT数据量呈对数线性关系,而随着预训练模型的改进,增长逐渐减小。

因此,为了提升模型性能,可以通过降低预训练损失、提升SFT数据量。降低预训练损失需要大量的工作,而提升SFT数据量似乎更简单一些。为此,论文提出了利用模型本身来生成更多的SFT数据,以增强模型的推理能力。

论文在SFT模型上应用RFT,以采样和选择正确的推理路径作为增强数据集使用这些增强的数据集来微调基础的LLM,这将比SFT取得更好的性能。

作者发现影响RFT性能的关键因素是独特的推理路径数量可以通过多次采样或结合多个模型输出来增加数量

论文讨论了RFT之所以有效的原因是因为它提供了多种推理路径,使得LLM具有更好的推理泛化能力。论文还讨论了RFT在计算资源方面比预训练便宜得多,而通过较低的预训练损失训练一个LLM可以显著提升模型性能。

图1-1:论文实验主要结果
  • 当预训练损失变小时(即预训练模型变得更好),在一定范围内,SFT和ICL的模型推理性能会线性提高。SFT的性能改进速度比ICL慢。
  • SFT随着数据量的增加呈对数线性方式提升。随着预训练模型的改善,增加数据量的好处逐渐减弱。
  • 随着不同推理路径数量的增加,RFT的模型性能也会得到提高。RFT的性能改进速度比SFT慢。
  • 从多个模型中组合拒绝采样进一步增强了RFT的性能,使得LLaMA-7B达到了49.3%的准确率(相比于SFT提高了13.4%),LLaMA2-7B达到了50.3%(+8.7%),LLaMA-13B达到了52.1%(+9.1%),以及LLaMA2-13B达到了55.4%(+5.4%)。

二、影响大模型数学推理能力的关键因素

我们期望预训练模型能够从有监督的推理数据集 中学习到推理能力。该数据集由 定义,其中 是一个问题, 是一条COT推理路径, 是一个数值答案。我们在数据集 上进行有监督微调以获得一个SFT模型 。然后使用 通过贪心解码来生成测试集中的推理路径和答案。

2.1 模型准确率 VS pretrain loss

之前的研究表明,更大的LLM表现出更好的推理能力,但是论文作者发现 Llama 比GPT-3表现更好,这表明模型参数量不应该作为衡量推理能力的唯一指标。虽然 LLMs 具有不同的架构、模型参数和预训练 token 数,但作者发现预训练损失才是衡量数学推理能力的稳定性指标,因此论文使用预训练损失来代表模型,而不是使用它们的参数量和 token 数。

下图反应了预训练损失和模型性能的关系。

图2-1:预训练损失和模型性能的关系

通过图2-1,可以发现:

  • 预训练损失与SFT的准确率呈负线性相关
  • SFT在性能上始终优于ICL,而当预训练损失较低时,这种改进会减小。

从观察来看,提高推理能力的一种有效方法是训练一个具有更低预训练损失的更好基础模型(预训练就是您所需要的!)。具有较低预训练损失的模型从微调中获得的改进较小,这可能是因为这些模型在预训练期间已经获得了更多的推理能力,并且有监督数据能提供的信号较少来指导它们。

2.2 模型准确率 VS 有监督数据量

SFT 确实提高了LLM的推理能力,但是我们更想知道监督数据的数量如何影响模型的改进。

论文对Llama和Llama2进行了微调,使用了来自GSM8K训练集的{1, 1/2, 1/4, 1/8, 1/16, 1/32}倍数据量。下图是数据量和模型准确率的关系。

图2-2:模型准确率和数据量的关系

从图3中,我们可以观察到:

  • 模型性能与数据量的关系呈对数线性关系。
  • 更好的模型需要更多的数据量才能超过其ICL性能。
  • 当监督数据量翻倍时,更好的模型受益较少。

对数线性关系在训练数据的{1, 1/2, 1/4, 1/8}比例范围内是稳定的。从图中可以直观地看出,扩大训练数据集可以提高性能,特别是对于较差的模型而言。对于较好的模型,其好处较少,这呼应了更好的模型在大语言预训练期间已经学到了更多的推理能力。

2.3 模型准确率 VS 增强数据量

增加数学推理数据集往往非常困难,特别是提出新问题。对于一个受过良好教育的学生来说,每天解决数百个数学应用题很容易,但是很难想出多样化和具有教育意义的数学问题。因此,论文的方向改变为使用现有数据来扩充新数据

论文尝试了扩充新query,但是这些方法与SFT相比并没有显著提升,而RFT却是一种简单有效的方法,可以扩充新的推理路径并提高模型性能。作者发现,影响RFT的关键因素是独特的推理路径数量

RFT(Rejection Sampling Fine-tuning)

模型 获得了COT推理能力, 现在使用 来生成更多正确的推理路径 作为新的训练数据集。

对于每个 , 生成 个候选推理路径和答案, temperature为 0.7 。

首先, 根据Python评估过滤掉具有错误答案的推理路径 或基于Python评估的错误推理路径)。每个推理路径包含一个方程列表 , 我们为每个不同的方程列表选择一个推理路径 作为增强数据, 并删除与其他具有相同方程列表的推理路径, 以消除相似的推理路径。不同数字顺序(例如 或不同方程顺序(例如 被视为不同的方程。这对于模型了解这些顺序可以交换是有帮助的, 并且仅凭一个问题的一个推理路径很难让模型学习到这一点。论文将 定义为增强数据集, 其中:

在预训练LLM上使用 进行微调, 得到新的RFT模型 。表1中列出了在Llama和Llama-2上使用 采样候选推理路径的RFT结果。

表1:SFT和RFT在GSM8K上的准确率对比

设置 7B 7B-2 13B 13B-2 33B
pretrain-loss 1.8 1.75 1.73 1.68 1.62
ICL 11.0 14.6 17.8 28.7 35.6
SFT 35.9 41.6 43.0 50.0 54.6
RFT(k = 100) 41.7 47.5 49.1 54.8 54.5
每个问题的正确推理路径数量 53.3 60.8 62.5 71.6 88.7
每个问题的不同推理路径数量 5.25 5.19 5.26 5.29 2.78

对于33B模型,与SFT相比,RFT没有提高性能。主要原因是来自拒绝采样的增强样本。作者发现更好的模型可以为每个问题生成更多的正确推理路径。对于LLaMA33B-SFT,它平均每问题可以生成88.7条正确的路径。然而,它过度拟合了训练集,并且在训练集问题上难以生成更多样化的路径(只能生成2.78条不同的推理路径)。

对于33B的拒绝采样非常耗时,论文尝试使用更大的temperature=1.0来解码LLaMA-33B-SFT模型,它生成了82.4条正确的路径和每问题4.77条独特的路径,这比使用temperature=0.7更多样化,但仍然不如7B和13B模型的多样性。

模型准确性 VS 拒绝采样数据量

为了了解RFT的性能,论文设置了多组不同的采样 k,在k=1、3、6、12、25、50、100的情况下比较RFT的效果。此外还设置了另一个 k = 100 的情况,在这种情况下不删除任何推理路径,称之为 no dedup。图2-3是RFT在不同采样次数k下的GSM8K性能表现。

图2-3:RFT在不同采样次数k下的GSM8K性能表现

图2-3中列出了不同k值的RFT结果。使用 k = 100 的 RFT 和 no dedup 的结果相差不大,这表明基于独特的推理路径数量来估计RFT性能比基于RFT增强样本数量更好。此外,对于4个模型中的3个,使用去重操作具有更好的性能,并且需要更少的训练时间。

当使用 k = 3 时,RFT比SFT稳定地高出2个点。对于大多数情况来说,使用较大的k会取得更好的性能。然而,随着 k 的翻倍,RFT 的优势在逐渐减小。

表2中计算了不同 k 对应的每个问题的不同路径数。我们可以看到,随着 k 的增长,不同的推理路径数量并没有快速增长。而从图3中我们知道,训练样本翻倍可以带来线性的性能提升。然而推理路径翻倍带来的提升幅度小于训练样本翻倍,这是因为获得不同的推理路径并不能得到任何新问题

表2:不同SFT模型针对每个问题生成的不同推理路径

k 7B 7B-2 13B 13B-2 33B
1 1.17 1.19 1.15 1.18 1.06
3 1.44 1.47 1.41 1.45 1.16
6 1.74 1.78 1.69 1.76 1.28 1.76
12 2.20 2.23 2.11 2.21 1.46
25 2.93 2.93 2.88 2.94 1.77
50 3.94 3.91 3.90 3.94 2.19
100 5.25 5.19 5.26 5.29 2.78
400(U13B) 12.84 12.84 12.84 12.84 12.84
500(U33B) 13.65 13.65 13.65 13.65 13.65

通过结合来自多个模型的拒绝采样样本, 上面的实验结果证明了在数学推理方面的性能提升, 这得益于拒绝采样。拒绝采样可以丰富训练数据, 使其包含多种计算过程的推理路径。而从一个单一的 SFT模型采样的推理路径可能缺乏多样性。因此, 论文通过利用不同模型聚合的拒绝采样推理路径来进一步改善数学推理性能。将两个最终的数据集表示为 , 它们分别来自于不同模型的拒绝采样。

, 其中U表示特定尺寸的模型, 7B/13B/33B表示LLaMA-7B/13B/33B,而7B2/13B2表示LLaMA2-7B/13B。表示一个聚合过程, 在这个过程中, 不同集合的所有推理路径首先被组合起来, 然后应用图2中的“选择推理路径”算法来消除关于方程形式和顺序具有相同计算过程的推理路径。

通过下面的图2-4可以看出, 相比于仅用单个模型的数据集进行微调, 使用聚合数据集 可以在不同规模的模型上一致性更好的效果。RFT在这两个增强数据集上的表现缩小了 SFT和RFT k=100之间的性能差距, 这意味着这些组合的增强数据集提供了足够的推理数据以填补预训练的差距。

图2-4:从多个模型中获取拒绝采样样本

在33B模型上应用RFT 非常消耗资源, 并且需要一个适度的网格搜索才能取得比SFT更好的效果。然而, 进行微调的计算成本与在33B上采样100次的成本相当, 并取得了更好的性能。

另一个现象是, 将 包含在聚合中几乎不影响性能。为了更全面地分析结果, 表2中计算了每个问题平均推理路径的数量, 而图6中中显示了不同推理路径的来源。

在表 2 中, 的平均推理路径数量远超单个模型的数量, 而 仅比 多出0.81个推理路径。与此同时, 如图2-5所示, 大小在13B及以下的模型在大约15%的情况下为 提供了独特的推理路径。然而, 只有6.5%的推理路径可以独自从LLaMA-33B-SFT模型获得。这表明, 当采样训练问题时, 33B的SFT模型提供的推理多样性有限。这一发现与表1的结果一致, 表明33B模型(可能还有65B和70B模型)能够很好地记住人类注释的推理路径。对于65B 模型, 论文发现使用 并没有比SFT提高性能。原因可能是更好的模型从有监督样本中获益较少, 因为它在预训练期间已经学到了更多的推理能力。

图2-5:每个模型为D'U33B提供的推理路径的比例。例如,在黄色的部分中,D'U33B中的15.5%的推理路径只能从LLaMA2-13B-SFT结果中找到。

总的来说,我们可以得出结论:

  • 通过来自SFT模型拒绝采样出的多样化推理路径提高了LLM的数学推理性能,聚合更多样化的推理路径可以进一步提高性能。
  • 不同的SFT模型可以通过拒绝采样生成不同的推理路径,从而为RFT提供更多样化的训练数据。更大参数规模的LLM可能会因为过拟合问题导致生成的独特推理路径减少。

A、实验细节

论文使用NVIDIA A100 GPU在GSM8K数据集上进行微调,训练3个epoch,批处理大小为128。

对于7B和13B模型,论文使用了8个GPU;对于33B模型,使用16个GPU;而对于65B和70B模型,使用32个GPU。采用的学习率为2e-5。在最后一个epoch评估结果。

下面是详细的实验结果:

下面是通过拒绝采样为RFT生成的不同推理路径的案例,计算结果用红色突出显示。

正如前面所述,RFT考虑了关于方程形式或顺序的不同计算过程的推理路径,从而得到正确答案。在上图的案例中,所有来自RFT的推理路径都得到了正确的答案10,而推理的计算过程是多种多样的。路径1和2以及路径4和5在红色突出显示的方程形式上有所不同。路径1和2展示了一个两步计算推理过程,而路径4和5改变为一个一步计算推理过程。这个案例说明了拒绝采样可能能够提供更多的监督信号,以提高数学推理性能。从LLM本身采样过滤后的推理路径与来自人类注释的推理示例具有相似的质量

参考

Scaling Relationship on Learning Mathematical Reasoning with Large Language Models(https://arxiv.org/pdf/2308.01825)

https://github.com/OFA-Sys/gsm8k-ScRel(https://github.com/OFA-Sys/gsm8k-ScRel)


公众号后台回复“数据集”获取100+深度学习各方向资源整理

极市干货

技术专栏:多模态大模型超详细解读专栏搞懂Tranformer系列ICCV2023论文解读极市直播
极视角动态欢迎高校师生申报极视角2023年教育部产学合作协同育人项目新视野+智慧脑,「无人机+AI」成为道路智能巡检好帮手!
技术综述:四万字详解Neural ODE:用神经网络去刻画非离散的状态变化transformer的细节到底是怎么样的?Transformer 连环18问!

点击阅读原文进入CV社区

收获更多技术干货

【声明】内容源于网络
0
0
极市平台
为计算机视觉开发者提供全流程算法开发训练平台,以及大咖技术分享、社区交流、竞赛实践等丰富的内容与服务。
内容 8155
粉丝 0
极市平台 为计算机视觉开发者提供全流程算法开发训练平台,以及大咖技术分享、社区交流、竞赛实践等丰富的内容与服务。
总阅读919
粉丝0
内容8.2k