大数跨境
0
0

vLLM皇冠上的明珠:深入浅出理解PagedAttention CUDA实现

vLLM皇冠上的明珠:深入浅出理解PagedAttention CUDA实现 极市平台
2024-01-11
0
↑ 点击蓝字 关注极市平台
者丨方佳瑞@知乎(已授权)
来源丨https://zhuanlan.zhihu.com/p/673284781
编辑丨极市平台

极市导读

 

vLLM的paged attention v1实现继承自FasterTransformers MHA实现,它和FlashAttention的并行任务划分方式不同。其中对KVCache layout的设计比较巧妙,充分利用了shared memory写带宽,是一种常用CUDA编程技巧 >>加入极市CV技术交流群,走在计算机视觉的最前沿

当前, 在大模型推理框架领域,vLLM以其卓越的高吞吐性能和简洁易读的代码而备受瞩目,已经成为许多团队二次开发的首选。其优雅的设计和高效的实现不仅使其在实际应用中表现出色,也使其成为学习和理解推理框架的理想典范。PagedAttention(PA)技术是vLLM的基石,以它为创新点的论文发表在系统顶会SOSP 2023上。

Efficient Memory Management for Large Language Model Serving with PagedAttention
(https//dl.acm.org/doi/abs/10.1145/3600006.3613165)

vLLM中,LLM推理的prefill阶段attention计算使用第三方库xformers的优化实现,decoding阶段attention计算则使用项目编译CUDA代码实现。具体代码在vllm的csrc/attention/attention_kernels.cu文件里,开发者洋洋洒洒写了八百多行CUDA代码。Attention计算时使用页式(paged)管理KVCache用于增加服务吞吐率,但对延迟有负面影响,因此高效的PA实现方法,利用页式内存管理同时尽量降低其负面影响,对框架的综合性能表现至关重要。

本文章将描述PA CUDA Kernel的实现细节,这些细节是公开的论文和博客所不涉及的,但却对框架的速度至关重要。另外,PA实现改编自FasterTransformers某个版本的MHA实现,NV原始版本对GPU特性的运用也是相当老道的,值得大家借鉴。

vLLM中有两个版本PA,使用一个简单的启发式方法来决定是使用V1还是V2版本。V1是本文介绍的版本,改编自FasterTransformers的MHA实现。V2是参考FlashDecoding方式进行实现,对sequence维度进行切分以增加并行粒度,关于FlashDecoding可以参考本人知乎文章。V1适合长度小于8192或者num_seqs * num_heads>512的情况。

阅读本文有很高的门槛,需要读者熟悉CUDA编程模型,并看过vLLM PagedAttention的博客或论文。

笔者学习PA时,笔者曾受益于@zzk again(www.zhihu.com/people/236dedf0a4bc04040b68aea7e2e6e06a)和@SiriusNEO(www.zhihu.com/people/1d602505863ff752e038d17590464590)的文章。

zzk again:PageAttention代码走读(https://zhuanlan.zhihu.com/p/668736097)

SiriusNEO:LLM 高速推理框架 vLLM 源代码分析  vLLM Source Code Analysis(https://zhuanlan.zhihu.com/p/641999400)

知乎上有很多分析PA源码的文章,本文和他们不同在于,这里不是代码阅读笔记,而是CUDA并行算法设计角度描述PA工作流程,达到我来深入你来浅出的目的,让读者可以真正理解PA实现细节,帮助更多人复现和优化PAv1。

1. 参数定义和数据结构

  • num_seq:本次推理请求sequence数目。
  • num_head:Query的head数目。
  • num_kv_heads:Key、Value的head数目,对于MHA和num_head相同,如果是GQA、MQA则num_kv_heads小于num_head。
  • head_size hidden dimension,特征的维度。

PA使用tensor的维度信息:

  • out [num_seqs, num_heads, head_size]
  • Q [num_seqs, num_heads, head_size]
  • KCache [num_blocks, num_kv_heads, head_size/x, block_size, x]:x表示一个向量化的大小,如float16 -> 16 / sizeof(float16) = 8。
  • VCache [num_blocks, num_kv_heads, head_size, block_size]

Paged内存管理相关的辅助数据结构:

  • blk_size:也就是block_size,是KVCache page的最高维,KVCache是若干个page的集合,每个page存(blk_size, num_head,head_size)个K、V的元素。
  • head_mapping [num_heads] 用于MQA, GQA,确定用的KV_head
  • block_tables [num_seqs, max_num_blocks_per_seq] block_tables映射表,表示每个sequence映射到哪几个block上
  • context_lens [num_seqs] 用于变长

2. 课前问题:

如果你能回答以下两个问题,那么说明你已经非常熟练地掌握了PA实现,并可以用批判性的眼光审阅本文,找出其中可能存在的错误。如果你暂时无法回答这些问题,请不要担忧,阅读完本文后会给你答案。

Q1:为什么K Cache的layout和V Cache layout不一样?

Q2:PA实现和FlashAttention有什么区别?

3. PagedAttention算子计算流程:

首先,按照CUDA编程模型对任务进行并行划分,grid大小(num_heads, num_seqs),grid中每个CUDA thread block大小(NUM_THREADS),NUM_THREADS是常量默认为128,也就说每个thread block包含128个线程,负责完成output矩阵一行(包含head_size个元素)结果的attention计算任务。thread block中的线程进一步划分若干个WARP。众所周知,WARP是GPU一个基本的执行单元,由32个线程组成,这些线程以SMIT方式在硬件上同时执行相同的指令,在不同的数据上进行操作。在PA中比较特殊的是,warp内32个线程进一步划分为blk_size个thread group,这和paged KVCache设计x息息相关的,马上会细讲。

Attention计算softmax(QK^T)V,一图胜前言,后面流程介绍将围绕下面这幅图展开。其中thread block, warp, thread group, thread别用不同颜色表示。

图1:PagedAttention CUDA计算流程

在上图的左侧部分,我们看到了Q矩阵,这部分描述了从显存读取Q数据到共享内存的过程。在这个过程中,一个CUDA线程块会读取图中Q矩阵的一行(包含head_size个元素)并将其存入共享内存。这个过程是通过一个循环来实现的,在每次迭代中,每个thread group会读取16字节的Q数据(例如,如果使用float16,那么就是8个元素)。每个warp会读取16*blk_size字节的Q数据,这些数据对应于一个sequence的一个head,由CUDA grid索引指定。当循环访问结束后,共享内存存储Q行的一部分。如下图所示,绿色部分表示存储在一个线程读入共享内存中的数据。

图1中上面部分K矩阵部分描述了从显存读取K Cache到寄存器的过程。每个序列的K Cache包含cxt_length * num_kv_heads * head_size个元素,但由于采用了页式内存管理,这些元素在内存中的存储并不连续。每个thread block只负责计算一个sequence一个head的QK^T,因此只需要ctx_length * head_size个K Cache元素。然而,由于ctx_length维度的存储是不连续的,并且以blk_size个token为粒度分布在不同的内存地址,我们需要根据query的head_idx和seq_idx访问block_table以找到K Cache的physical_block_num。为了方便后续的描述,我们可以将K Cache视为(:, head_size)的形状,其中head_size个元素组成一行。

K Cache的布局为[num_blocks, num_kv_heads, head_size/x, block_size, x],这是为了优化写入shared memory的操作。在Q和K矩阵的同一行元素被读入寄存器并进行点乘运算后,结果需要被存入shared memory。如果一个warp中所有线程都计算Q、K同一行数据,会导致写入shared memory的同一个位置,这将造成warp内不同线程顺序地写入。因此,为了优化,warp的线程最好计算Q和K的不同行数据。因此,在设计K Cache布局时,我们将block_size放在比head_size更低的维度。由于warp size大于block_size,我们需要将head_size拆分为head_size/x和x两个维度,借x到最低维度,以确保每个线程读入的数据量和计算量都足够大。最后,每个线程组派一个线程去写入shared memory,这样一个warp有blk_size个线程并行写入shared memory,从而增加了shared memory的访问带宽。这种设计策略是为了实现高效的并行计算和内存访问,以提高整体的计算性能。

在代码实现中,访问K矩阵需要一个循环,该循环使得CUDA线程块中的所有warp依次访问num_block个页面。在每次循环迭代中,每个warp负责访问连续的blk_size个K Cache行,这涉及到的数据量为blk_size * head_size个元素。同时,每个thread group负责访问K Cache的一行,将head_size个元素加载到自己的寄存器中。接着,寄存器中的Q和K数据元素立即进行点乘运算,运算结果被写入shared memory中。因此,线程块的shared memory存储了一行QK^T的结果,包含ctx_length个元素。这种实现方式充分利用了CUDA的并行计算能力,以提高数据处理的效率。

然后,thread block对shared memory中元素进行max,sum方式reduction,然后计算得到softmax结果。

图1右边V矩阵部分描述从显存读V Cache到寄存器过程。和K Cache一样,CUDA thread block依次访问num_blk个物理块到寄存器,每个warp负责blk_size个token的page内存,page的真实物理地址同样需要进行索引。不过这里不需要以thread group为单位访问16字节,而是每个thread访问16字节的元素。访问完就可以与shared memory的softmax(QK^T)中间结果对应位置16字节的数据进行点乘,得到一个float结果,写到output对应位置中。

为什么V Cache的layout是 [num_blocks, num_kv_heads, head_size, block_size],和K Cache layout不一样?这是因为V要去做点乘的对象在shared memory,只需要读,不涉及并行写的问题。

和FlashAttention(FA)有什么不同?结合我的图和中间FAv2的流程图对比就一目了然了。FA用了两层循环,每次写一个Tile的output tensor,而PA一直只有一层循环,每次写一行output tensor。因为每次都有整行的QK^T中间结果,不需要online softmax这种花哨技巧。

4. PAv1的问题

以我粗浅的理解指出几点vLLM PAv1的问题。一、和MHA相比,MQA和GAQ没有减少对KV Cache的读写次数。读K、V Cache时候只是做了一个head_idx的转换,会重复从显存读相同的head。二、对于seq length很长情况没法适应,因为没有沿着ctx_length或者batch维度做切分。这点FlashAttention和FlashDecoding就做了,因此PAv2借鉴了FA的切分思想。

5. 总结:

vLLM的paged attention v1实现继承自FasterTransformers MHA实现,它和FlashAttention的并行任务划分方式不同。其中对KVCache layout的设计比较巧妙,充分利用了shared memory写带宽,是一种常用CUDA编程技巧。

本文是Attention算子优化宇宙第四篇,对Attention优化感兴趣同学可以一起服用:

方佳瑞:大模型训练加速之FlashAttention系列:爆款工作背后的产品观(https://zhuanlan.zhihu.com/p/664061672)

方佳瑞:大模型推理加速之Flash Decoding:更小子任务提升并行度(https://zhuanlan.zhihu.com/p/664264445)

方佳瑞:大模型推理加速之FlashDecoding++:野生Flash抵达战场(https://zhuanlan.zhihu.com/p/665361668)

公众号后台回复“数据集”获取100+深度学习各方向资源整理

极市干货

技术专栏:多模态大模型超详细解读专栏搞懂Tranformer系列ICCV2023论文解读极市直播
极视角动态欢迎高校师生申报极视角2023年教育部产学合作协同育人项目新视野+智慧脑,「无人机+AI」成为道路智能巡检好帮手!
技术综述:四万字详解Neural ODE:用神经网络去刻画非离散的状态变化transformer的细节到底是怎么样的?Transformer 连环18问!

点击阅读原文进入CV社区

收获更多技术干货

【声明】内容源于网络
0
0
极市平台
为计算机视觉开发者提供全流程算法开发训练平台,以及大咖技术分享、社区交流、竞赛实践等丰富的内容与服务。
内容 8155
粉丝 0
极市平台 为计算机视觉开发者提供全流程算法开发训练平台,以及大咖技术分享、社区交流、竞赛实践等丰富的内容与服务。
总阅读919
粉丝0
内容8.2k