大数跨境
0
0

Transformers是如何实现大模型的投机采样的?

Transformers是如何实现大模型的投机采样的? 极市平台
2023-09-06
0
↑ 点击蓝字 关注极市平台
作者丨良睦路程序员@知乎(已授权)
来源丨https://zhuanlan.zhihu.com/p/653935025
编辑丨极市平台

极市导读

 

一文搞懂投机性解码原理和具体流程。>>关注公众号,后台回复「极市干货」即可获取最新整理CV知识内容合集

背景

今天看到Niels Rogge转发了一个推文,介绍了投机性解码,感觉非常有意思,就研究了一下。

恰好,我又看到新智元的一篇文章:【不用4个H100!340亿参数Code Llama在Mac可跑,每秒20个token,代码生成最拿手|Karpathy转赞】,仔细看了一下,好家伙,竟然就是上面那篇推文介绍的东西。

因此,打算从代码角度,好好研究一下投机性解码。

投机性解码原理

简单的来说:

  1. 就是使用一个小模型来做草稿,然后使用大模型做纠正检查。
  2. 小模型的参数量要远小于原模型参数量一个级别才效果明显。
  3. 小模型和原模型的tokenizer最好一模一样,不然会增加额外的解码、编码时间。

具体流程如下:

假设,让小模型预测5步

第一步

原始token序列为蓝色序列

第二步

使用辅助模型生成5个新的token序列(红色)

流程是这样的

第三步

将蓝色序列和红色序列拼接在一起,放入原始模型中,并且生成一个绿色结果

第四步

考虑到大模型的最后一步都是使用矩阵计算概率,那么在第三步中,看似生成了一个绿色结果。实际上,利用了矩阵计算的并行性:一步计算,就可以验证小模型生成的5个结果对不对。

第五步

看第四步生成的结果,可以发现:

  1. 小模型生成的第一个token是466,但是大模型生成的第一个token是651。
  2. 小模型生成的序列中,只要有一个错了,那后面就不能要了。
  3. 因此不用小模型结果,使用大模型的第一个结果651(绿色小块)作为本轮的结果。
  4. 以此循环,直到遇到结束符才停止。

第四步扩充

可能还是有人没搞懂:明明大模型利用蓝色和红色块推理得到了绿色338,但是没怎么用到338。而且怎么一次性把绿色块的651、428、287、475、340计算出来的。

这个疑惑的产生,主要是因为没有把casualLM类型的模型搞懂。casualLM类型的模型,都是预测下一个token的概率。而且使用的是矩阵计算,所以一次性把所有token对应的下一个位置,都预测出来了。

第四步里面的绿色的651是怎么产生的:

第四步里面的绿色的428是怎么生成的:

第四步里面的绿色的287是怎么生成的:

第四步里面的绿色的475是怎么生成的:

第四步里面的绿色的340怎么生成的:

第四步里面的绿色的338是怎么生成的:

这个时候,你在回看第四步的图(这里复制了一份),发现:

  1. 小模型预测的第一个tokenid就已经错了(红色的466),大模型预测的是(绿色的651)。
  2. 这种自回归序列模型,一步错,步步错。因此,虽然后面的小模型预测了一个340(红色)和大模型预测的340(绿色是一样的),但是完全不能用。
  3. 考虑到大模型基于原始的token list(蓝色),预测了651(绿色)。那就把651拿来用了。
  4. 虽然小模型预测了5步,全都错了,但是也不亏。因为小模型的5步的计算时间,远远小于大模型一次预测的时间。

解释

可以想象一下:

  1. 大模型直接利用原始token(蓝色序列),要预测新的5个token,计算起码需要跑5次。
  2. 先使用小模型预测5个试一试,然后大模型借助矩阵计算的并行特性,一次性就可以验证这5个中,前几个都是对的。
  3. 如果有对的,那节约的时间可不是一点点(因为小模型远小于大模型,所以小模型消耗的时间基本可以忽略不记)。
  4. 这个思想很简单,举个例子:树上全是枣子,旁边又有竹竿,那你肯定拿起竹竿,在空中挥了一挥。能打到枣子,算走运了,没打到枣子,也不亏。那么这个投机生成也是这个道理。

代码部分

  1. 虽然这个逻辑很简单,而且也说了,是使用了矩阵的并行性,但是在刚开始,我并不知道他在代码里面是如何实现的。
  2. 虽然投机采样也就是在8月31号才火起来,但是代码在4月份的时候,就已经在transformers包里面实现了。

代码链接为:https://github.com/huggingface/transformers/blob/4b796978656e461177a83d58ec3c2b06152c63db/src/transformers/generation/utils.py#L4269

如何使用

在huggingface的transformers包里面,已经给到一个使用案例了,代码如下:

from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
LogitsProcessorList,
MinLengthLogitsProcessor,
StoppingCriteriaList,
MaxLengthCriteria,
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
assistant_model = AutoModelForCausalLM.from_pretrained("distilgpt2")
# set pad_token_id to eos_token_id because GPT2 does not have a PAD token
model.generation_config.pad_token_id = model.generation_config.eos_token_id
input_prompt = "It might be possible to"
input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids
# instantiate logits processors
logits_processor = LogitsProcessorList(
[
MinLengthLogitsProcessor(10, eos_token_id=model.generation_config.eos_token_id),
]
)
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
outputs = model.assisted_decoding(
input_ids,
assistant_model=assistant_model,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
)
tokenizer.batch_decode(outputs, skip_special_tokens=True)

源码解读

是怎么的出来,是使用矩阵的并行原理,可以在这里看到:

            # 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain
# `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct,
# we use this forward pass to also pick the subsequent logits in the original model.

# 2.1. Run a forward pass on the candidate sequence
if "past_key_values" in model_kwargs:
...
else:
if self.config.is_encoder_decoder:
...
else:
outputs = self(
candidate_input_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_cache=True,
)

在这里,把原始的token和辅助模型(在huggingface的代码里面,叫辅助模型,但是和上面的小模型是一回事,叫法不一样)生成的token绑定在一起,然后放入原始模型做推理。

在下面的代码块中,把大模型批量处理的结果提取出来,和辅助模型生成的结果做比对。

# 2.2. Process the new logits
new_logits = outputs.logits[:, -candidate_length - 1 :] # excludes the input prompt if present
if len(logits_processor) > 0:
for i in range(candidate_length):
new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
if len(logits_warper) > 0:
for i in range(candidate_length):
new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])

# 3. Obtain the next tokens from the original model logits.
if do_sample:
probs = new_logits[:, -candidate_length - 1 :, :].softmax(dim=-1)
selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
else:
selected_tokens = new_logits[:, -candidate_length - 1 :, :].argmax(dim=-1)

# 4. Compare the argmax from the original model logits with the assistant forecasted tokens. We can keep
# the assistant forecasted tokens until the first mismatch, or until the max length is reached.
candidate_new_tokens = candidate_input_ids[:, -candidate_length:]
n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()


误区

其实,刚开始,还以为是把辅助模型生成的新token逐步复制,从batchsize=1变成batchsize=5。但在debug的时候,解除了我的困惑。

因为数据上显示batchsize=1,一直没有变化,就是取了logits的不同位置。

然后想起来CasualLM类型的模型,最后一层都是nn.Linear结构(将hidden_states转换成logits概率),然后就想到了是使用矩阵的并行原理实现的。

参考链接

  1. Niels Rogge的推特链接: https://twitter.com/NielsRogge/status/1697335383166472294
  2. 新智元的那个文章的知乎链接: https://zhuanlan.zhihu.com/p/653729679
  3. huggingface的辅助生产文章链接: https://huggingface.co/blog/zh/assisted-generation

最后

本文图解投机采样的原理,并且介绍了其代码实现,如果有写错的地方,大佬们多多指导~


公众号后台回复“数据集”获取100+深度学习各方向资源整理

极市干货

技术专栏:多模态大模型超详细解读专栏搞懂Tranformer系列ICCV2023论文解读极市直播
极视角动态欢迎高校师生申报极视角2023年教育部产学合作协同育人项目新视野+智慧脑,「无人机+AI」成为道路智能巡检好帮手!
技术综述:四万字详解Neural ODE:用神经网络去刻画非离散的状态变化transformer的细节到底是怎么样的?Transformer 连环18问!

点击阅读原文进入CV社区

收获更多技术干货

【声明】内容源于网络
0
0
极市平台
为计算机视觉开发者提供全流程算法开发训练平台,以及大咖技术分享、社区交流、竞赛实践等丰富的内容与服务。
内容 8155
粉丝 0
极市平台 为计算机视觉开发者提供全流程算法开发训练平台,以及大咖技术分享、社区交流、竞赛实践等丰富的内容与服务。
总阅读197
粉丝0
内容8.2k