大数跨境
0
0

理解并实现LLM中的KV缓存:从零开始的完整指南

理解并实现LLM中的KV缓存:从零开始的完整指南 ai算法芯片与系统
2025-12-07
4
导读:KV缓存是大语言模型高效推理的关键技术,通过存储和重用注意力机制中的键和值向量,显著减少重复计算,提升文本生成速度。本文从零开始,提供可读性强的完整代码实现,涵盖概念基础、实现细节、性能影响及优化策略

 

摘要

KV缓存(KV Cache)是大语言模型(LLM)高效推理中最关键的技术之一。通过在推理过程中存储和重用注意力机制(Attention Mechanism) 中计算得到的键(Key)和值(Value)向量,KV缓存能够显著减少重复计算,从而极大地提升文本生成速度。本文将深入探讨KV缓存的工作原理,并提供一个从零开始、注重可读性的完整代码实现。我们将涵盖其概念基础、实现细节、性能影响以及实际部署中的优化策略,帮助你全面理解这一核心推理优化技术。

目录

  1. 1. 概述:为什么需要KV缓存?
  2. 2. 什么是KV缓存?一个直观的理解
  3. 3. LLM如何生成文本:无缓存与有缓存的对比
  4. 4. 从零开始实现KV缓存
    • • 4.1 注册缓存缓冲区
    • • 4.2 前向传播与 use_cache 标志
    • • 4.3 清理缓存
    • • 4.4 在完整模型中传播 use_cache
    • • 4.5 在文本生成中使用缓存
  5. 5. 简单的性能对比
  6. 6. KV缓存的优势与劣势
  7. 7. 优化KV缓存实现
    • • 7.1 预分配内存
    • • 7.2 滑动窗口截断缓存
    • • 7.3 实践中的优化效果
  8. 8. 结论
  9. 9. 补充:Qwen3与Llama 3中的KV缓存

1. 概述:为什么需要KV缓存?

在LLM的生产部署中,推理效率至关重要。KV缓存正是一种为提升推理阶段计算效率而设计的关键技术。简而言之,KV缓存存储了注意力机制中的中间键(K)和值(V)计算结果以供后续推理步骤重用,从而带来可观的加速效果。

当然,KV缓存并非没有代价:它增加了代码复杂性,提高了内存占用,并且无法在训练阶段使用。然而,对于生产环境中的LLM应用而言,推理速度的大幅提升通常足以抵消代码复杂性和内存方面的权衡。

2. 什么是KV缓存?一个直观的理解

让我们想象一个LLM正在生成文本。假设模型收到提示词(Prompt):“Time”。正如我们所知,LLM每次生成一个词(或一个token)。接下来的两个文本生成步骤可能如下图所示:

该图展示了LLM如何逐个token生成文本。从提示词“Time”开始,模型生成下一个token“flies”。在下一步中,完整的序列“Time flies”被重新处理以生成token“fast”。

注意,在生成的LLM文本输出中存在一些冗余,如下图所示高亮部分:

此图强调了在每个生成步骤中必须由LLM重新处理的重复上下文(“Time flies”)。由于LLM没有缓存中间键/值状态,它每次生成新token(如“fast”)时都会重新编码整个序列。

当我们实现LLM文本生成函数时,通常只使用每一步生成的最后一个token。然而,上图从概念层面凸显了一个主要的低效之处。如果我们深入到注意力机制(Attention Mechanism) 内部,这种低效性(或冗余)会更加清晰。

下图展示了LLM核心注意力机制计算的一个片段。此处,输入token(“Time”和“flies”)被编码为3维向量(现实中这些向量维度更大)。矩阵 W 是注意力机制中用于将这些输入转换为键(Key)、值(Value)和查询(Query)向量的权重矩阵。

该图高亮了键和值向量相关的计算片段:

此图说明了LLM在注意力计算期间如何从token嵌入中推导出键(k)和值(v)向量。每个输入token(例如“Time”和“flies”)通过学习的矩阵 W_k 和 W_v 进行投影,以获得其对应的键和值向量。

如前所述,LLM逐个生成token。假设LLM生成了单词“fast”,那么下一轮的提示词变为“Time flies fast”。如下图所示:

此图展示了在每个生成步骤中,LLM如何为先前见过的token(“Time”和“flies”)重新计算键和值向量。当生成第三个token(“fast”)时,模型再次重新计算相同的 k(1)/v(1) 和 k(2)/v(2) 向量,而不是复用它们。这种重复计算突显了在自回归解码(Autoregressive Decoding)中不使用KV缓存的低效性。

通过比较前两张图,我们可以清楚地看到前两个token的键和值向量是完全相同的,在每一轮后续的文本生成中重新计算它们将是巨大的浪费。

KV缓存的核心思想就是实现一种缓存机制,存储先前生成的键和值向量以供重用,从而帮助我们避免这些不必要的重新计算。

3. LLM如何生成文本:无缓存与有缓存的对比

在理解了上一节的基本概念后,让我们在查看具体代码实现前,再深入一些细节。对于“Time flies fast”这个文本生成过程,如果没有KV缓存,我们可以这样理解:

无KV缓存的文本生成过程示意图,展示了每个步骤都需要重新计算整个历史序列的键值对,存在明显的计算冗余。

请注意其中的冗余:token “Time” 和 “flies” 在每个新的生成步骤都被重新计算。KV缓存通过存储和重用先前计算好的键值向量来解决这种低效问题:

  1. 1. 初始化:模型计算输入token的键值向量并将其缓存。
  2. 2. 增量生成:对于每个新生成的token,模型仅计算该特定token的键值向量。
  3. 3. 缓存重用:从缓存中检索先前计算的向量,避免冗余计算。

下表总结了计算和缓存的步骤与状态:

KV缓存工作流程总结表,展示了在三个生成步骤中,哪些键值向量是新增计算的,哪些是从缓存中重用的。

这样做的好处是:“Time”被计算了一次,重用了两次;“flies”被计算了一次,重用了两次。这是一个简短的示例,但可以直观地看出,文本越长,我们能够重用的已计算键值就越多,生成速度的提升也就越显著

下图并排展示了使用和不使用KV缓存的第三个生成步骤:

对比有/无KV缓存的文本生成。在上方面板(无缓存)中,每个token步骤都需要重新计算键值向量,导致冗余操作。在下方面板(有缓存)中,先前计算的键值从KV缓存中检索出来,避免了重新计算,从而实现更快生成。

因此,如果我们想在代码中实现KV缓存,所需要做的就是照常计算键和值,然后存储它们,以便在下一轮中检索。下一节将通过一个具体的代码示例来说明。

4. 从零开始实现KV缓存

实现KV缓存的方法有很多,主要思想是在每个生成步骤中,我们只计算新生成token的键值张量

我选择了一种强调代码可读性的简单实现。浏览代码修改部分可能是理解其实现方式的最简单途径。

我已在GitHub上分享了两个文件,它们是两个独立的Python脚本,分别实现了带KV缓存和不带KV缓存的LLM:

  • • gpt_ch04.py:取自我的《从零开始构建大语言模型》一书第3章和第4章的自包含代码,用于实现LLM并运行简单的文本生成函数。
  • • gpt_with_kv_cache.py:与上述相同,但进行了必要的修改以实现KV缓存。

为了阅读与KV缓存相关的代码修改,你可以:
a. 打开 gpt_with_kv_cache.py 文件,查找标记为 # NEW 的新增部分。
b. 使用你喜欢的文件对比工具来比较两个代码文件的差异。

以下是实现细节的简要概述。

4.1 注册缓存缓冲区

在 MultiHeadAttention 类的构造函数中,我们添加两个缓冲区 cache_k 和 cache_v,用于跨步骤保存拼接后的键和值:


   
    
   self.register_buffer(“cache_k”, None)
self
.register_buffer(“cache_v”, None)

(如果你想了解更多关于缓冲区的知识,可以参考我的YouTube视频《理解PyTorch缓冲区》。)

4.2 前向传播与 use_cache 标志

接下来,我们扩展 MultiHeadAttention 类的 forward 方法,使其接受一个 use_cache 参数:


   
    
   def forward(self, x, use_cache=False):
    b, num_tokens, d_in = x.shape

    keys_new = self.W_key(x)  # Shape: (b, num_tokens, d_out)
    values_new = self.W_value(x)
    queries = self.W_query(x)
    #…


    if
 use_cache:
        if
 self.cache_k is None:
            self
.cache_k, self.cache_v = keys_new, values_new
        else
:
            self
.cache_k = torch.cat([self.cache_k, keys_new], dim=1)
            self
.cache_v = torch.cat([self.cache_v, values_new], dim=1)
        keys, values = self.cache_k, self.cache_v
    else
:
        keys, values = keys_new, values_new

这里的存储检索机制实现了KV缓存的核心思想。

  • • 存储 (Storing):具体来说,通过 if self.cache_k is None: … 初始化缓存后,我们分别通过 self.cache_k = torch.cat(…) 和 self.cache_v = torch.cat(…) 将新生成的键和值添加到缓存中。
  • • 检索 (Retrieving):然后,通过 keys, values = self.cache_k, self.cache_v 从缓存中检索存储的键和值。

这基本上就是KV缓存的核心:一个存储与检索机制。接下来的第3和第4节只是处理一些次要的实现细节。

4.3 清理缓存

在生成文本时,我们必须记得在两次独立的文本生成调用之间重置键值缓冲区。否则,新提示的查询会关注到上一个序列遗留的陈旧键,导致模型依赖不相关的上下文并产生不连贯的输出。为了防止这种情况,我们向 MultiHeadAttention 类添加一个 reset_kv_cache 方法,以便稍后在文本生成调用之间使用:


   
    
   def reset_cache(self):
    self
.cache_k, self.cache_v = None, None

4.4 在完整模型中传播 use_cache

在完成对 MultiHeadAttention 类的修改后,我们现在修改 GPTModel 类。首先,在构造函数中添加一个用于跟踪token索引位置的计数器:


   
    
   self.current_pos = 0

这是一个简单的计数器,用于记录模型在增量生成会话中已经缓存了多少个token。

然后,我们将单行的块调用替换为一个显式循环,将 use_cache 参数传递给每个Transformer块:


   
    
   def forward(self, in_idx, use_cache=False):
    # …


    if
 use_cache:
        pos_ids = torch.arange(
            self
.current_pos, self.current_pos + seq_len,
            device=in_idx.device, dtype=torch.long
        )
        self
.current_pos += seq_len
    else
:
        pos_ids = torch.arange(
            0
, seq_len, device=in_idx.device, dtype=torch.long
        )

    pos_embeds = self.pos_emb(pos_ids).unsqueeze(0)
    x = tok_embeds + pos_embeds
    # …

    for
 blk in self.trf_blocks:
        x = blk(x, use_cache=use_cache)

如果我们设置 use_cache=True,上面代码的作用是:从 self.current_pos 开始计数 seq_len 步。然后递增计数器,以便下一次解码调用能从上一次结束的位置继续。

self.current_pos 跟踪的原因是:新的查询必须与已存储的键和值在位置上精确对齐。如果不使用计数器,每个新步骤都会从位置0开始,模型会将新token视为与早期token重叠。(或者,我们也可以通过 offset = block.att.cache_k.shape[1] 来跟踪。)

上述更改还需要对 TransformerBlock 类进行一个小修改,以接受 use_cache 参数:


   
    
   def forward(self, x, use_cache=False):
    # …

    self
.att(x, use_cache=use_cache)

最后,为了使用方便,我们在 GPTModel 中添加一个模型级的重置方法,以一次性清除所有块的缓存:


   
    
   def reset_kv_cache(self):
    for
 blk in self.trf_blocks:
        blk.att.reset_cache()
    self
.current_pos = 0

4.5 在文本生成中使用缓存

通过对 GPTModelTransformerBlock 和 MultiHeadAttention 的修改,最后,我们在一个简单的文本生成函数中使用KV缓存:


   
    
   def generate_text_simple_cached(model, idx, max_new_tokens, use_cache=True):
    model.eval()

    ctx_len = model.pos_emb.num_embeddings  # 最大支持长度,例如 1024
    if
 use_cache:
        # 用完整的提示词初始化缓存

        model.reset_kv_cache()
        with
 torch.no_grad():
            logits = model(idx[:, -ctx_len:], use_cache=True)

        for
 _ in range(max_new_tokens):
            # a) 选取具有最高对数概率的token

            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
            # b) 将其附加到运行序列中

            idx = torch.cat([idx, next_idx], dim=1)
            # c) 仅将新token喂给模型

            with
 torch.no_grad():
                logits = model(next_idx, use_cache=True)
    else
:
        for
 _ in range(max_new_tokens):
            with
 torch.no_grad():
                logits = model(idx[:, -ctx_len:], use_cache=False)
            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
            idx = torch.cat([idx, next_idx], dim=1)

    return
 idx

请注意,在c)步骤中,我们仅通过 logits = model(next_idx, use_cache=True) 将新token输入模型。如果没有缓存,由于没有存储的键值可供重用,我们需要将整个输入 logits = model(idx[:, -ctx_len:], use_cache=False) 喂给模型。

5. 简单的性能对比

在概念层面介绍了KV缓存之后,一个关键的问题是它在实际中的表现如何。我们可以运行上述两个Python脚本来尝试实现效果,这两个脚本将运行一个124M参数的小型LLM,以生成200个新token(给定一个4-token的提示词“Hello, I am”开头):


   
    
   pip install -r https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/refs/heads/main/requirements.txt

python gpt_ch04.py

python gpt_with_kv_cache.py

在一台搭载M4芯片(CPU)的Mac Mini上,结果如下:

性能对比结果图。左侧为无KV缓存,生成200个token耗时约101秒。右侧为有KV缓存,生成200个token耗时约21秒,加速比约为4.8倍。

由此可见,即使在一个小型的124M参数模型和较短的200-token序列长度上,我们已经获得了约5倍的加速。(请注意,此实现以代码可读性为优化目标,并未针对CUDA或MPS运行时速度进行优化,后者需要预分配张量而不是重新实例化和拼接它们。)

注意:模型在两种情况下都生成了“乱码”,即看起来像这样的文本:
Hello, I am Featureiman Byeswickattribute argue logger Normandy Compton analogous bore ITVEGIN ministriesysics Kle functional recountrictionchangingVirgin embarrassedgl …
这是因为我们尚未训练模型。你可以在训练后的模型上使用KV缓存(但请注意,KV缓存仅用于推理阶段)来生成连贯的文本。这里,我们使用未训练的模型是为了保持代码的简洁。

然而,更重要的是,gpt_ch04.py 和 gpt_with_kv_cache.py 的实现生成了完全相同的文本。这告诉我们KV缓存实现是正确的——在实现过程中很容易出现索引错误,导致结果出现偏差。

6. KV缓存的优势与劣势

随着序列长度的增加,KV缓存的优势和劣势在以下方面变得更加明显:

方面
影响
说明
计算效率 显著提升
 (优势)
无缓存时,第 t 步的注意力必须将新查询与 t 个历史键进行比较,
累积工作量按   缩放。
有缓存时,每个键值对仅计算一次然后重用,
将总步长复杂度降低到线性 
内存占用 线性增长 
(劣势)
每个新token都会追加到KV缓存中。
对于长序列和大型LLM,累积的KV缓存会变得非常大,
可能消耗大量甚至无法承受的(GPU)内存。
作为变通方案,我们可以截断KV缓存,但这会增加额外的复杂性。
代码复杂度 增加 
(劣势)
需要管理缓存状态(初始化、更新、清理),
并在前向传播中引入条件逻辑。
适用范围 仅限于推理 
(限制)
KV缓存利用了自回归生成中上下文重复的特性,
在训练阶段(需要完整的并行计算)无法使用。

7. 优化KV缓存实现

虽然我上面的概念性KV缓存实现有助于清晰理解,并且主要面向代码可读性和教育目的,但在实际场景中部署它(特别是对于更大模型和更长序列)需要更仔细的优化。

7.1 预分配内存

持续使用 torch.cat 拼接张量(如前所示)会由于频繁的内存分配和重新分配而导致性能瓶颈。相反,我们可以根据预期的最大序列长度预分配一个足够大的张量。这确保了内存使用的一致性并减少了开销。伪代码如下所示:


   
    
   # 为键和值预分配内存的示例
max_seq_len = 1024  # 预期最大序列长度
cache_k = torch.zeros(
    (batch_size, num_heads, max_seq_len, head_dim), device=device
)
cache_v = torch.zeros(
    (batch_size, num_heads, max_seq_len, head_dim), device=device
)

在推理过程中,我们可以直接写入这些预分配张量的切片。

7.2 滑动窗口截断缓存

为了避免GPU内存爆炸,我们可以实现一个带动态截断的滑动窗口(Sliding Window) 方法。通过滑动窗口,我们只在缓存中保留最近的 window_size 个token:


   
    
   # 滑动窗口缓存实现
window_size = 512
cache_k = cache_k[:, :, -window_size:, :]
cache_v = cache_v[:, :, -window_size:, :]

7.3 实践中的优化效果

你可以在 gpt_with_kv_cache_optimized.py 文件中找到这些优化。

在一台搭载M4芯片(CPU)的Mac Mini上,进行200个token的生成,并将窗口大小设置为LLM的上下文长度(以保证结果相同从而实现公平比较),代码运行时间对比如下:

优化后的性能对比图。左侧为无缓存(101秒),中间为基本KV缓存(21秒),右侧为优化后的KV缓存(19秒),显示进一步的速度提升。

不幸的是,在CUDA设备上,这种速度优势会消失,因为这是一个非常小的模型,并且设备传输和通信的开销超过了KV缓存对于这个小模型带来的好处。

8. 结论

KV缓存是加速LLM自回归推理的一项基础且强大的技术。通过在生成步骤间存储和重用注意力键值向量,它避免了大量的重复计算,从而带来显著的性能提升,尤其是在CPU和长序列生成场景下。

尽管缓存引入了额外的复杂性和内存考量,但其带来的效率提升通常足以抵消这些权衡,尤其是在生产环境中。

请记住,虽然我在这里优先考虑了代码清晰度和可读性而非极致效率,但实际应用中的实现往往需要深思熟虑的优化,例如预分配内存或应用滑动窗口缓存来有效管理内存增长。从这个意义上说,我希望这篇文章能提供丰富的信息。

欢迎尝试这些技术,并祝编程愉快!

9. 补充:Qwen3与Llama 3中的KV缓存

在我为从零开始实现的Qwen3 (0.6B) 和 Llama 3 (1B) 模型添加KV缓存后,我运行了额外的实验来比较使用和不使用KV缓存的模型运行时。请注意,我选择了上文提到的 torch.cat 方法,而非“优化KV缓存实现”一节中描述的预分配KV缓存张量。由于Llama 3和Qwen3支持非常大的上下文长度(分别为131k和41k个token),预分配的张量会消耗约8 GB的额外内存,这是相当昂贵的。

此外,因为我使用了更节省内存的动态创建张量的 torch.cat 方法,我将KV缓存移到了模型外部,以便使用 torch.compile 编译模型来提升计算效率。

性能对比如下所示。

Qwen3 (0.6B) 模型性能对比

Qwen3 0.6B模型在不同配置下的生成时间对比。展示了无缓存、有缓存、编译后无缓存、编译后有缓存等在CPU和GPU上的性能差异。

Llama 3 (1B) 模型性能对比

Llama 3 1B模型在不同配置下的生成时间对比。展示了无缓存、有缓存、编译后无缓存、编译后有缓存等在CPU和GPU上的性能差异。

正如我们所看到的,在CPU上,KV缓存带来了最显著的加速效果。而编译(torch.compile)则进一步提升了性能。然而,在GPU上,最佳性能可以通过常规的编译模型实现,这可能是因为我们没有在GPU上预分配张量,而且模型相对较小。这强调了优化策略需要根据目标硬件平台模型规模进行具体调整。

 


【声明】内容源于网络
0
0
ai算法芯片与系统
长期关注ai领域,算法,芯片,软件(系统,框架,编译器,算子库)等联合设计
内容 196
粉丝 0
ai算法芯片与系统 长期关注ai领域,算法,芯片,软件(系统,框架,编译器,算子库)等联合设计
总阅读172
粉丝0
内容196