极市导读
一文搞懂投机性解码原理和具体流程。>>关注公众号,后台回复「极市干货」即可获取最新整理CV知识内容合集
背景
今天看到Niels Rogge转发了一个推文,介绍了投机性解码,感觉非常有意思,就研究了一下。
恰好,我又看到新智元的一篇文章:【不用4个H100!340亿参数Code Llama在Mac可跑,每秒20个token,代码生成最拿手|Karpathy转赞】,仔细看了一下,好家伙,竟然就是上面那篇推文介绍的东西。
因此,打算从代码角度,好好研究一下投机性解码。
投机性解码原理
简单的来说:
-
就是使用一个小模型来做草稿,然后使用大模型做纠正检查。 -
小模型的参数量要远小于原模型参数量一个级别才效果明显。 -
小模型和原模型的tokenizer最好一模一样,不然会增加额外的解码、编码时间。
具体流程如下:
假设,让小模型预测5步
第一步
原始token序列为蓝色序列
第二步
使用辅助模型生成5个新的token序列(红色)
流程是这样的
第三步
将蓝色序列和红色序列拼接在一起,放入原始模型中,并且生成一个绿色结果
第四步
考虑到大模型的最后一步都是使用矩阵计算概率,那么在第三步中,看似生成了一个绿色结果。实际上,利用了矩阵计算的并行性:一步计算,就可以验证小模型生成的5个结果对不对。
第五步
看第四步生成的结果,可以发现:
-
小模型生成的第一个token是466,但是大模型生成的第一个token是651。 -
小模型生成的序列中,只要有一个错了,那后面就不能要了。 -
因此不用小模型结果,使用大模型的第一个结果651(绿色小块)作为本轮的结果。 -
以此循环,直到遇到结束符才停止。
第四步扩充
可能还是有人没搞懂:明明大模型利用蓝色和红色块推理得到了绿色338,但是没怎么用到338。而且怎么一次性把绿色块的651、428、287、475、340计算出来的。
这个疑惑的产生,主要是因为没有把casualLM类型的模型搞懂。casualLM类型的模型,都是预测下一个token的概率。而且使用的是矩阵计算,所以一次性把所有token对应的下一个位置,都预测出来了。
第四步里面的绿色的651是怎么产生的:
第四步里面的绿色的428是怎么生成的:
第四步里面的绿色的287是怎么生成的:
第四步里面的绿色的475是怎么生成的:
第四步里面的绿色的340怎么生成的:
第四步里面的绿色的338是怎么生成的:
这个时候,你在回看第四步的图(这里复制了一份),发现:
-
小模型预测的第一个tokenid就已经错了(红色的466),大模型预测的是(绿色的651)。 -
这种自回归序列模型,一步错,步步错。因此,虽然后面的小模型预测了一个340(红色)和大模型预测的340(绿色是一样的),但是完全不能用。 -
考虑到大模型基于原始的token list(蓝色),预测了651(绿色)。那就把651拿来用了。 -
虽然小模型预测了5步,全都错了,但是也不亏。因为小模型的5步的计算时间,远远小于大模型一次预测的时间。
解释
可以想象一下:
-
大模型直接利用原始token(蓝色序列),要预测新的5个token,计算起码需要跑5次。 -
先使用小模型预测5个试一试,然后大模型借助矩阵计算的并行特性,一次性就可以验证这5个中,前几个都是对的。 -
如果有对的,那节约的时间可不是一点点(因为小模型远小于大模型,所以小模型消耗的时间基本可以忽略不记)。 -
这个思想很简单,举个例子:树上全是枣子,旁边又有竹竿,那你肯定拿起竹竿,在空中挥了一挥。能打到枣子,算走运了,没打到枣子,也不亏。那么这个投机生成也是这个道理。
代码部分
-
虽然这个逻辑很简单,而且也说了,是使用了矩阵的并行性,但是在刚开始,我并不知道他在代码里面是如何实现的。 -
虽然投机采样也就是在8月31号才火起来,但是代码在4月份的时候,就已经在transformers包里面实现了。
代码链接为:https://github.com/huggingface/transformers/blob/4b796978656e461177a83d58ec3c2b06152c63db/src/transformers/generation/utils.py#L4269
如何使用
在huggingface的transformers包里面,已经给到一个使用案例了,代码如下:
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
LogitsProcessorList,
MinLengthLogitsProcessor,
StoppingCriteriaList,
MaxLengthCriteria,
)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
assistant_model = AutoModelForCausalLM.from_pretrained("distilgpt2")
# set pad_token_id to eos_token_id because GPT2 does not have a PAD token
model.generation_config.pad_token_id = model.generation_config.eos_token_id
input_prompt = "It might be possible to"
input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids
# instantiate logits processors
logits_processor = LogitsProcessorList(
[
MinLengthLogitsProcessor(10, eos_token_id=model.generation_config.eos_token_id),
]
)
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
outputs = model.assisted_decoding(
input_ids,
assistant_model=assistant_model,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
)
tokenizer.batch_decode(outputs, skip_special_tokens=True)
源码解读
是怎么的出来,是使用矩阵的并行原理,可以在这里看到:
# 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain
# `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct,
# we use this forward pass to also pick the subsequent logits in the original model.
# 2.1. Run a forward pass on the candidate sequence
if "past_key_values" in model_kwargs:
...
else:
if self.config.is_encoder_decoder:
...
else:
outputs = self(
candidate_input_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_cache=True,
)
在这里,把原始的token和辅助模型(在huggingface的代码里面,叫辅助模型,但是和上面的小模型是一回事,叫法不一样)生成的token绑定在一起,然后放入原始模型做推理。
在下面的代码块中,把大模型批量处理的结果提取出来,和辅助模型生成的结果做比对。
# 2.2. Process the new logits
new_logits = outputs.logits[:, -candidate_length - 1 :] # excludes the input prompt if present
if len(logits_processor) > 0:
for i in range(candidate_length):
new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
if len(logits_warper) > 0:
for i in range(candidate_length):
new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
# 3. Obtain the next tokens from the original model logits.
if do_sample:
probs = new_logits[:, -candidate_length - 1 :, :].softmax(dim=-1)
selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
else:
selected_tokens = new_logits[:, -candidate_length - 1 :, :].argmax(dim=-1)
# 4. Compare the argmax from the original model logits with the assistant forecasted tokens. We can keep
# the assistant forecasted tokens until the first mismatch, or until the max length is reached.
candidate_new_tokens = candidate_input_ids[:, -candidate_length:]
n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()
误区
其实,刚开始,还以为是把辅助模型生成的新token逐步复制,从batchsize=1变成batchsize=5。但在debug的时候,解除了我的困惑。
因为数据上显示batchsize=1,一直没有变化,就是取了logits的不同位置。
然后想起来CasualLM类型的模型,最后一层都是nn.Linear结构(将hidden_states转换成logits概率),然后就想到了是使用矩阵的并行原理实现的。
参考链接
-
Niels Rogge的推特链接: https://twitter.com/NielsRogge/status/1697335383166472294 -
新智元的那个文章的知乎链接: https://zhuanlan.zhihu.com/p/653729679 -
huggingface的辅助生产文章链接: https://huggingface.co/blog/zh/assisted-generation
最后
本文图解投机采样的原理,并且介绍了其代码实现,如果有写错的地方,大佬们多多指导~

公众号后台回复“数据集”获取100+深度学习各方向资源整理
极市干货

点击阅读原文进入CV社区
收获更多技术干货

