理解KV缓存、其工作机制以及与原始架构的比较
目录
-
• 引言 -
• 什么是KV缓存? -
• 为什么需要KV缓存? -
• KV缓存的工作机制 -
• 原始Transformer与带KV缓存的Transformer的FLOPs比较 -
• 带KV缓存的Transformer的FLOPs分析 -
• 性能对比与权衡 -
• 实际应用考虑 -
• 结论 -
• 参考文献
引言
Transformer架构已成为自然语言处理领域的核心基础,特别是在大型语言模型(LLM)中。随着模型规模的不断扩大,推理效率成为实际应用中的关键挑战。在本系列文章中,我们将系统探讨Transformer模型的各种优化技术。作为开篇,本文将深入分析KV缓存(Key-Value Cache)这一重要的推理优化技术。
什么是KV缓存?
KV缓存是一种通过存储和重用先前计算的键值对来提升自回归语言模型推理性能的技术。该技术能够在不影响模型准确性的前提下,显著减少推理过程中的计算量并降低端到端延迟。
在标准的自回归生成过程中,每次生成新token时,模型都需要重新处理所有先前的token。KV缓存通过保存这些中间计算结果,避免了重复计算,从而提高了效率。
为什么需要KV缓存?
自回归语言模型(如GPT系列)在生成文本时具有特定的计算模式:每次生成新token都需要基于之前生成的所有token进行计算。这种计算模式导致了显著的效率问题。
考虑以下生成示例:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
class Sampler:
def __init__(self, model_name: str = 'gpt2-medium') -> None:
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForCausalLM.from_pretrained(model_name).to("cpu").to(self.device)
def encode(self, text):
return self.tokenizer.encode(text, return_tensors='pt').to(self.device)
def decode(self, ids):
return self.tokenizer.decode(ids)
def get_next_token_prob(self, input_ids: torch.Tensor):
with torch.no_grad():
logits = self.model(input_ids=input_ids).logits
logits = logits[0, -1, :]
return logits
class GreedySampler(Sampler):
def __call__(self, prompt, max_new_tokens=10):
predictions = []
result = prompt
for i in range(max_new_tokens):
print(f"step {i} input: {result}")
input_ids = self.encode(result)
next_token_probs = self.get_next_token_prob(input_ids=input_ids)
id = torch.argmax(next_token_probs, dim=-1).item()
result += self.decode(id)
predictions.append(next_token_probs[id].item())
return result
gs = GreedySampler()
gs(prompt="Large language models are recent advances in deep learning", max_new_tokens=10)
执行过程输出:
step 0 input: Large language models are recent advances in deep learning
step 1 input: Large language models are recent advances in deep learning,
step 2 input: Large language models are recent advances in deep learning, which
step 3 input: Large language models are recent advances in deep learning, which uses
step 4 input: Large language models are recent advances in deep learning, which uses deep
step 5 input: Large language models are recent advances in deep learning, which uses deep neural
step 6 input: Large language models are recent advances in deep learning, which uses deep neural networks
step 7 input: Large language models are recent advances in deep learning, which uses deep neural networks to
step 8 input: Large language models are recent advances in deep learning, which uses deep neural networks to learn
step 9 input: Large language models are recent advances in deep learning, which uses deep neural networks to learn to
从上述示例可以看出,随着生成过程的进行,输入序列长度不断增加,导致每次推理的计算量呈平方级增长。这种计算模式在生成长文本时尤为低效。
关键问题在于:如果没有KV缓存,每次生成新token时,模型都需要重新计算所有历史token的隐藏状态。这意味着在LLM推理过程中,每个生成步骤都要重复处理整个历史序列,造成了大量的冗余计算。
KV缓存的工作机制
基本原理
KV缓存的核心思想是存储每个Transformer层中注意力机制的键(Key)和值(Value)矩阵。在生成新token时,只需计算当前token的查询(Query)向量,然后将其与之前所有token的缓存键值对进行注意力计算。
具体实现中,每个注意力头维护独立的KV缓存:
class SelfAttention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
# ... 其他初始化代码
self.cache_k = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim))
self.cache_v = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim))
在前向传播过程中,缓存被逐步填充和访问:
def forward(
self,
x: torch.Tensor,
start_pos: int,
freqs_complex: torch.Tensor
):
# ... 其他计算
# 输入形状: (B, 1, Dim)
# xk形状: (B, 1, H_KV, Head_Dim)
# xv形状: (B, 1, H_KV, Head_Dim)
# 更新缓存
self.cache_k[:batch_size, start_pos : start_pos + seq_len] = xk
self.cache_v[:batch_size, start_pos : start_pos + seq_len] = xv
# 从缓存中获取所有之前的键值对
# 形状: (B, Seq_Len_KV, H_KV, Head_Dim)
keys = self.cache_k[:batch_size, : start_pos + seq_len]
values = self.cache_v[:batch_size, : start_pos + seq_len]
数学表达
假设在第 个Transformer层生成token 。使用KV缓存的计算分为两个部分:
-
1. KV缓存更新: -
2. 注意力计算:
可视化表示
考虑一个具有12个注意力头和KV缓存的Transformer架构。下图展示了在生成输入序列的第9个token时的状态:
如图所示,在生成新token时,只有当前token需要经过完整的线性变换,而之前的token则直接从缓存中读取其键值表示。在LLM推理过程中,历史状态通过KV缓存得以保存和重用,避免了重复计算。
原始Transformer与带KV缓存的Transformer的FLOPs比较
FLOPs基本概念
FLOPs(浮点运算次数)是衡量计算复杂度的关键指标,表示执行浮点数运算的总次数。对于矩阵乘法,FLOPs的计算方法如下:
设 ,计算 需要:
-
• 乘法运算: 次 -
• 加法运算: 次
总FLOPs为:
基本符号定义
|
|
|
|---|---|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
原始Transformer的FLOPs分析
原始Transformer在每次生成新token时都需要重新处理整个序列,其计算量分析如下:
1. 自注意力块
步骤1:Q、K、V投影
-
• 输入形状: -
• 权重形状: -
• 每个投影的计算量: -
• Q、K、V总计算量:
步骤2:注意力计算
-
• 计算: -
• : -
• 输出投影 :
步骤3:MLP块
-
• 第一个线性层: -
• 第二个线性层:
步骤4:词汇表投影
-
• 计算量:
单层Transformer总计算量:
层Transformer总计算量:
带KV缓存的Transformer的FLOPs分析
使用KV缓存后,每次生成新token时只需处理当前token,输入形状变为 :
1. 自注意力块
步骤1:Q、K、V投影
-
• 输入形状: -
• Q、K、V总计算量:
步骤2:注意力计算
-
• 计算: -
• : -
• 输出投影 :
步骤3:MLP块
-
• 第一个线性层: -
• 第二个线性层:
步骤4:词汇表投影
-
• 计算量:
单层Transformer总计算量:
层Transformer总计算量:
性能对比与权衡
计算复杂度分析
两种方法的计算复杂度对比如下:
|
|
|
|
|---|---|---|
|
|
||
|
|
||
|
|
||
|
|
内存与计算权衡
KV缓存虽然减少了计算量,但增加了内存开销:
-
• 计算优势:将注意力计算的复杂度从 降低到 -
• 内存代价:需要存储每层的KV缓存,总内存需求为
定量比较
定义两种方法的FLOPs函数:
原始Transformer:
带KV缓存的Transformer:
当序列长度 足够大时,带KV缓存的Transformer在计算效率上具有明显优势:
实际应用考虑
适用场景
KV缓存特别适用于以下场景:
-
1. 自回归文本生成:如对话系统、文本补全等 -
2. 长文本生成:生成文档、长篇文章等 -
3. 实时推理应用:需要低延迟响应的场景
实现注意事项
-
1. 缓存管理:需要有效管理缓存空间,防止内存溢出 -
2. 批处理优化:在批处理推理中,需要处理不同序列长度的缓存 -
3. 内存带宽限制:当缓存较大时,内存带宽可能成为瓶颈
局限性
-
1. 内存占用:KV缓存需要额外的内存空间 -
2. 初始延迟:在生成第一个token时,仍需计算整个输入序列的KV缓存 -
3. 动态序列长度:对于可变长度输入,缓存管理更加复杂
结论
KV缓存是Transformer推理优化中的关键技术,通过空间换时间的策略显著提升了自回归生成的效率。本文通过详细的理论分析和数学推导,展示了KV缓存的工作原理和性能优势。
主要结论:
-
1. 计算效率提升:KV缓存将注意力计算的复杂度从 降低到 ,在生成长序列时优势明显 -
2. 内存权衡:虽然增加了内存开销,但在大多数实际场景中,这种权衡是可接受的 -
3. 历史状态管理:在LLM推理过程中,KV缓存有效地保存和重用历史状态,避免了每次生成新token时的重复计算 -
4. 实际价值:对于部署大型语言模型,KV缓存是提高推理速度和降低延迟的关键技术
在后续文章中,我们将继续探讨其他Transformer优化技术,如量化、蒸馏、算子融合等,为构建高效的推理系统提供全面的技术视角。
参考文献
-
1. Vaswani, A. et al. "Attention is All You Need." NeurIPS 2017. -
2. Pope, R. et al. "Efficiently Scaling Transformer Inference." arXiv:2211.05102. -
3. Hugging Face Transformers Library. https://github.com/huggingface/transformers -
4. Llama Implementation. https://github.com/hkproj/pytorch-llama

