摘要
KV缓存(KV Cache)是大语言模型(LLM)高效推理中最关键的技术之一。通过在推理过程中存储和重用注意力机制(Attention Mechanism) 中计算得到的键(Key)和值(Value)向量,KV缓存能够显著减少重复计算,从而极大地提升文本生成速度。本文将深入探讨KV缓存的工作原理,并提供一个从零开始、注重可读性的完整代码实现。我们将涵盖其概念基础、实现细节、性能影响以及实际部署中的优化策略,帮助你全面理解这一核心推理优化技术。
目录
-
1. 概述:为什么需要KV缓存? -
2. 什么是KV缓存?一个直观的理解 -
3. LLM如何生成文本:无缓存与有缓存的对比 -
4. 从零开始实现KV缓存 -
• 4.1 注册缓存缓冲区 -
• 4.2 前向传播与 use_cache标志 -
• 4.3 清理缓存 -
• 4.4 在完整模型中传播 use_cache -
• 4.5 在文本生成中使用缓存 -
5. 简单的性能对比 -
6. KV缓存的优势与劣势 -
7. 优化KV缓存实现 -
• 7.1 预分配内存 -
• 7.2 滑动窗口截断缓存 -
• 7.3 实践中的优化效果 -
8. 结论 -
9. 补充:Qwen3与Llama 3中的KV缓存
1. 概述:为什么需要KV缓存?
在LLM的生产部署中,推理效率至关重要。KV缓存正是一种为提升推理阶段计算效率而设计的关键技术。简而言之,KV缓存存储了注意力机制中的中间键(K)和值(V)计算结果以供后续推理步骤重用,从而带来可观的加速效果。
当然,KV缓存并非没有代价:它增加了代码复杂性,提高了内存占用,并且无法在训练阶段使用。然而,对于生产环境中的LLM应用而言,推理速度的大幅提升通常足以抵消代码复杂性和内存方面的权衡。
2. 什么是KV缓存?一个直观的理解
让我们想象一个LLM正在生成文本。假设模型收到提示词(Prompt):“Time”。正如我们所知,LLM每次生成一个词(或一个token)。接下来的两个文本生成步骤可能如下图所示:
注意,在生成的LLM文本输出中存在一些冗余,如下图所示高亮部分:
当我们实现LLM文本生成函数时,通常只使用每一步生成的最后一个token。然而,上图从概念层面凸显了一个主要的低效之处。如果我们深入到注意力机制(Attention Mechanism) 内部,这种低效性(或冗余)会更加清晰。
下图展示了LLM核心注意力机制计算的一个片段。此处,输入token(“Time”和“flies”)被编码为3维向量(现实中这些向量维度更大)。矩阵 W 是注意力机制中用于将这些输入转换为键(Key)、值(Value)和查询(Query)向量的权重矩阵。
该图高亮了键和值向量相关的计算片段:
如前所述,LLM逐个生成token。假设LLM生成了单词“fast”,那么下一轮的提示词变为“Time flies fast”。如下图所示:
通过比较前两张图,我们可以清楚地看到前两个token的键和值向量是完全相同的,在每一轮后续的文本生成中重新计算它们将是巨大的浪费。
KV缓存的核心思想就是实现一种缓存机制,存储先前生成的键和值向量以供重用,从而帮助我们避免这些不必要的重新计算。
3. LLM如何生成文本:无缓存与有缓存的对比
在理解了上一节的基本概念后,让我们在查看具体代码实现前,再深入一些细节。对于“Time flies fast”这个文本生成过程,如果没有KV缓存,我们可以这样理解:
请注意其中的冗余:token “Time” 和 “flies” 在每个新的生成步骤都被重新计算。KV缓存通过存储和重用先前计算好的键值向量来解决这种低效问题:
-
1. 初始化:模型计算输入token的键值向量并将其缓存。 -
2. 增量生成:对于每个新生成的token,模型仅计算该特定token的键值向量。 -
3. 缓存重用:从缓存中检索先前计算的向量,避免冗余计算。
下表总结了计算和缓存的步骤与状态:
这样做的好处是:“Time”被计算了一次,重用了两次;“flies”被计算了一次,重用了两次。这是一个简短的示例,但可以直观地看出,文本越长,我们能够重用的已计算键值就越多,生成速度的提升也就越显著。
下图并排展示了使用和不使用KV缓存的第三个生成步骤:
因此,如果我们想在代码中实现KV缓存,所需要做的就是照常计算键和值,然后存储它们,以便在下一轮中检索。下一节将通过一个具体的代码示例来说明。
4. 从零开始实现KV缓存
实现KV缓存的方法有很多,主要思想是在每个生成步骤中,我们只计算新生成token的键值张量。
我选择了一种强调代码可读性的简单实现。浏览代码修改部分可能是理解其实现方式的最简单途径。
我已在GitHub上分享了两个文件,它们是两个独立的Python脚本,分别实现了带KV缓存和不带KV缓存的LLM:
-
• gpt_ch04.py:取自我的《从零开始构建大语言模型》一书第3章和第4章的自包含代码,用于实现LLM并运行简单的文本生成函数。 -
• gpt_with_kv_cache.py:与上述相同,但进行了必要的修改以实现KV缓存。
为了阅读与KV缓存相关的代码修改,你可以:
a. 打开 gpt_with_kv_cache.py 文件,查找标记为 # NEW 的新增部分。
b. 使用你喜欢的文件对比工具来比较两个代码文件的差异。
以下是实现细节的简要概述。
4.1 注册缓存缓冲区
在 MultiHeadAttention 类的构造函数中,我们添加两个缓冲区 cache_k 和 cache_v,用于跨步骤保存拼接后的键和值:
self.register_buffer(“cache_k”, None)
self.register_buffer(“cache_v”, None)
(如果你想了解更多关于缓冲区的知识,可以参考我的YouTube视频《理解PyTorch缓冲区》。)
4.2 前向传播与 use_cache 标志
接下来,我们扩展 MultiHeadAttention 类的 forward 方法,使其接受一个 use_cache 参数:
def forward(self, x, use_cache=False):
b, num_tokens, d_in = x.shape
keys_new = self.W_key(x) # Shape: (b, num_tokens, d_out)
values_new = self.W_value(x)
queries = self.W_query(x)
#…
if use_cache:
if self.cache_k is None:
self.cache_k, self.cache_v = keys_new, values_new
else:
self.cache_k = torch.cat([self.cache_k, keys_new], dim=1)
self.cache_v = torch.cat([self.cache_v, values_new], dim=1)
keys, values = self.cache_k, self.cache_v
else:
keys, values = keys_new, values_new
这里的存储和检索机制实现了KV缓存的核心思想。
-
• 存储 (Storing):具体来说,通过 if self.cache_k is None: …初始化缓存后,我们分别通过self.cache_k = torch.cat(…)和self.cache_v = torch.cat(…)将新生成的键和值添加到缓存中。 -
• 检索 (Retrieving):然后,通过 keys, values = self.cache_k, self.cache_v从缓存中检索存储的键和值。
这基本上就是KV缓存的核心:一个存储与检索机制。接下来的第3和第4节只是处理一些次要的实现细节。
4.3 清理缓存
在生成文本时,我们必须记得在两次独立的文本生成调用之间重置键值缓冲区。否则,新提示的查询会关注到上一个序列遗留的陈旧键,导致模型依赖不相关的上下文并产生不连贯的输出。为了防止这种情况,我们向 MultiHeadAttention 类添加一个 reset_kv_cache 方法,以便稍后在文本生成调用之间使用:
def reset_cache(self):
self.cache_k, self.cache_v = None, None
4.4 在完整模型中传播 use_cache
在完成对 MultiHeadAttention 类的修改后,我们现在修改 GPTModel 类。首先,在构造函数中添加一个用于跟踪token索引位置的计数器:
self.current_pos = 0
这是一个简单的计数器,用于记录模型在增量生成会话中已经缓存了多少个token。
然后,我们将单行的块调用替换为一个显式循环,将 use_cache 参数传递给每个Transformer块:
def forward(self, in_idx, use_cache=False):
# …
if use_cache:
pos_ids = torch.arange(
self.current_pos, self.current_pos + seq_len,
device=in_idx.device, dtype=torch.long
)
self.current_pos += seq_len
else:
pos_ids = torch.arange(
0, seq_len, device=in_idx.device, dtype=torch.long
)
pos_embeds = self.pos_emb(pos_ids).unsqueeze(0)
x = tok_embeds + pos_embeds
# …
for blk in self.trf_blocks:
x = blk(x, use_cache=use_cache)
如果我们设置 use_cache=True,上面代码的作用是:从 self.current_pos 开始计数 seq_len 步。然后递增计数器,以便下一次解码调用能从上一次结束的位置继续。
self.current_pos 跟踪的原因是:新的查询必须与已存储的键和值在位置上精确对齐。如果不使用计数器,每个新步骤都会从位置0开始,模型会将新token视为与早期token重叠。(或者,我们也可以通过 offset = block.att.cache_k.shape[1] 来跟踪。)
上述更改还需要对 TransformerBlock 类进行一个小修改,以接受 use_cache 参数:
def forward(self, x, use_cache=False):
# …
self.att(x, use_cache=use_cache)
最后,为了使用方便,我们在 GPTModel 中添加一个模型级的重置方法,以一次性清除所有块的缓存:
def reset_kv_cache(self):
for blk in self.trf_blocks:
blk.att.reset_cache()
self.current_pos = 0
4.5 在文本生成中使用缓存
通过对 GPTModel、TransformerBlock 和 MultiHeadAttention 的修改,最后,我们在一个简单的文本生成函数中使用KV缓存:
def generate_text_simple_cached(model, idx, max_new_tokens, use_cache=True):
model.eval()
ctx_len = model.pos_emb.num_embeddings # 最大支持长度,例如 1024
if use_cache:
# 用完整的提示词初始化缓存
model.reset_kv_cache()
with torch.no_grad():
logits = model(idx[:, -ctx_len:], use_cache=True)
for _ in range(max_new_tokens):
# a) 选取具有最高对数概率的token
next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
# b) 将其附加到运行序列中
idx = torch.cat([idx, next_idx], dim=1)
# c) 仅将新token喂给模型
with torch.no_grad():
logits = model(next_idx, use_cache=True)
else:
for _ in range(max_new_tokens):
with torch.no_grad():
logits = model(idx[:, -ctx_len:], use_cache=False)
next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
idx = torch.cat([idx, next_idx], dim=1)
return idx
请注意,在c)步骤中,我们仅通过 logits = model(next_idx, use_cache=True) 将新token输入模型。如果没有缓存,由于没有存储的键值可供重用,我们需要将整个输入 logits = model(idx[:, -ctx_len:], use_cache=False) 喂给模型。
5. 简单的性能对比
在概念层面介绍了KV缓存之后,一个关键的问题是它在实际中的表现如何。我们可以运行上述两个Python脚本来尝试实现效果,这两个脚本将运行一个124M参数的小型LLM,以生成200个新token(给定一个4-token的提示词“Hello, I am”开头):
pip install -r https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/refs/heads/main/requirements.txt
python gpt_ch04.py
python gpt_with_kv_cache.py
在一台搭载M4芯片(CPU)的Mac Mini上,结果如下:
由此可见,即使在一个小型的124M参数模型和较短的200-token序列长度上,我们已经获得了约5倍的加速。(请注意,此实现以代码可读性为优化目标,并未针对CUDA或MPS运行时速度进行优化,后者需要预分配张量而不是重新实例化和拼接它们。)
注意:模型在两种情况下都生成了“乱码”,即看起来像这样的文本:Hello, I am Featureiman Byeswickattribute argue logger Normandy Compton analogous bore ITVEGIN ministriesysics Kle functional recountrictionchangingVirgin embarrassedgl …
这是因为我们尚未训练模型。你可以在训练后的模型上使用KV缓存(但请注意,KV缓存仅用于推理阶段)来生成连贯的文本。这里,我们使用未训练的模型是为了保持代码的简洁。
然而,更重要的是,gpt_ch04.py 和 gpt_with_kv_cache.py 的实现生成了完全相同的文本。这告诉我们KV缓存实现是正确的——在实现过程中很容易出现索引错误,导致结果出现偏差。
6. KV缓存的优势与劣势
随着序列长度的增加,KV缓存的优势和劣势在以下方面变得更加明显:
7. 优化KV缓存实现
虽然我上面的概念性KV缓存实现有助于清晰理解,并且主要面向代码可读性和教育目的,但在实际场景中部署它(特别是对于更大模型和更长序列)需要更仔细的优化。
7.1 预分配内存
持续使用 torch.cat 拼接张量(如前所示)会由于频繁的内存分配和重新分配而导致性能瓶颈。相反,我们可以根据预期的最大序列长度预分配一个足够大的张量。这确保了内存使用的一致性并减少了开销。伪代码如下所示:
# 为键和值预分配内存的示例
max_seq_len = 1024 # 预期最大序列长度
cache_k = torch.zeros(
(batch_size, num_heads, max_seq_len, head_dim), device=device
)
cache_v = torch.zeros(
(batch_size, num_heads, max_seq_len, head_dim), device=device
)
在推理过程中,我们可以直接写入这些预分配张量的切片。
7.2 滑动窗口截断缓存
为了避免GPU内存爆炸,我们可以实现一个带动态截断的滑动窗口(Sliding Window) 方法。通过滑动窗口,我们只在缓存中保留最近的 window_size 个token:
# 滑动窗口缓存实现
window_size = 512
cache_k = cache_k[:, :, -window_size:, :]
cache_v = cache_v[:, :, -window_size:, :]
7.3 实践中的优化效果
你可以在 gpt_with_kv_cache_optimized.py 文件中找到这些优化。
在一台搭载M4芯片(CPU)的Mac Mini上,进行200个token的生成,并将窗口大小设置为LLM的上下文长度(以保证结果相同从而实现公平比较),代码运行时间对比如下:
不幸的是,在CUDA设备上,这种速度优势会消失,因为这是一个非常小的模型,并且设备传输和通信的开销超过了KV缓存对于这个小模型带来的好处。
8. 结论
KV缓存是加速LLM自回归推理的一项基础且强大的技术。通过在生成步骤间存储和重用注意力键值向量,它避免了大量的重复计算,从而带来显著的性能提升,尤其是在CPU和长序列生成场景下。
尽管缓存引入了额外的复杂性和内存考量,但其带来的效率提升通常足以抵消这些权衡,尤其是在生产环境中。
请记住,虽然我在这里优先考虑了代码清晰度和可读性而非极致效率,但实际应用中的实现往往需要深思熟虑的优化,例如预分配内存或应用滑动窗口缓存来有效管理内存增长。从这个意义上说,我希望这篇文章能提供丰富的信息。
欢迎尝试这些技术,并祝编程愉快!
9. 补充:Qwen3与Llama 3中的KV缓存
在我为从零开始实现的Qwen3 (0.6B) 和 Llama 3 (1B) 模型添加KV缓存后,我运行了额外的实验来比较使用和不使用KV缓存的模型运行时。请注意,我选择了上文提到的 torch.cat 方法,而非“优化KV缓存实现”一节中描述的预分配KV缓存张量。由于Llama 3和Qwen3支持非常大的上下文长度(分别为131k和41k个token),预分配的张量会消耗约8 GB的额外内存,这是相当昂贵的。
此外,因为我使用了更节省内存的动态创建张量的 torch.cat 方法,我将KV缓存移到了模型外部,以便使用 torch.compile 编译模型来提升计算效率。
性能对比如下所示。
Qwen3 (0.6B) 模型性能对比
Llama 3 (1B) 模型性能对比
正如我们所看到的,在CPU上,KV缓存带来了最显著的加速效果。而编译(torch.compile)则进一步提升了性能。然而,在GPU上,最佳性能可以通过常规的编译模型实现,这可能是因为我们没有在GPU上预分配张量,而且模型相对较小。这强调了优化策略需要根据目标硬件平台和模型规模进行具体调整。

