大数跨境
0
0

详解FlashAttention:加速计算,节省显存,,IO感知的精确注意力

详解FlashAttention:加速计算,节省显存,,IO感知的精确注意力 极市平台
2023-08-20
1
↑ 点击蓝字 关注极市平台
作者丨回旋托马斯x@知乎(已授权)
来源丨https://zhuanlan.zhihu.com/p/639228219
编辑丨极市平台

极市导读

 

一篇搞懂FlashAttention,深入分析FlashAttention加速attention计算的全过程,含有大量详细公式推导~ >>加入极市CV技术交流群,走在计算机视觉的最前

1. 前言

GPT3、LLaMA、ChatGLM、BLOOM等大语言模型输入输出的最大序列长度只有2048或4096, 扩展到更长序列的难度在哪里呢? 本质原因是, transformer模型的计算复杂度和空间复杂度都是 的, 其中 为序列长度。具体地, 当输入批次大小为 , 序列长度为 时, 层 transformer模型的计算量为 , 中间激活的显存大小为 , 其中 为注意力头数。可以看到, transformer模型的计算量和储存复杂度随着序列长度 呈二次方增长。 这限制了大语言模型的最大序列长度 的大小。计算量和中间激活见《分析transformer模型的参数量、计算量、中间激活、KV cache》。

最近, GPT4将最大序列长度 扩大到了 , Claude更是将最大序列长度 扩大到了100K, 这 些工作一定采用了一些优化方法来降低原生transformer的复杂度。我们知道, 每个transformer层分为两部分:self-attention块和MLP块。上面计算量中的 项和中间激活中的 项都是self-attention块产生的, 与MLP块无关。FlashAttention[1]提出了一种加速计算, 节省显存, 10 感知的精确注意力, 可以有效地缓解上述问题。Meta推出的开源大模型LLaMA[2], 阿联酋推出的开源大模型Falcon[2]都使用了Flash Attention来加速计算和节省显存。目前, Flash Attention已经集成到了pytorch2.0中, 另外triton、xformer等开源框架也进行了整合实现。

从论文题目《Flashattention: Fast and memory-efficient exact attention with io-awareness》入手,简要总结Flash Attention的优点。

  1. 加快了计算(Fast)。Flash Attention并没有减少计算量FLOPs,而是从IO感知出发,减少了HBM访问次数,从而减少了计算时间。论文中用到了"wall-clock time"这个词,时钟时间不仅包含了GPU运行耗时,还包含了IO读写的阻塞时间。减少HBM访问次数,是通过tiling技术分块和算子融合来实现的。

  2. 节省了显存 (Memory-efficient) 。Flash Attention通过引入统计量, 改变注意力机制的计算 顺序, 避免了实例化注意力矩阵 , 将显存复杂度从 降低到了 。这 个技巧不是首创的, 在[4][5]中也有应用。

  3. 精确注意力(Exact Attention)。不同于稀疏注意力,Flash Attention只是分块计算,而不是近似计算,Flash Attention与原生注意力的结果是完全等价的。

2. 背景知识

2.1 计算受限与内存受限

transformer的核心组件self-attention块的计算复杂度和空间复杂度是序列长度 的二次方,已经有许多近似注意力的方法尝试减少attention的计算和内存要求。例如,稀疏近似和低秩近似的方法,将计算复杂度降低到了序列长度的线性或亚线性,但这些方法并没有得到广泛应用。因为这些方法过于关注FLOPs(浮点数计算次数)的减少,而忽略了IO读写的内存访问开销。在现代GPU中,计算速度已经远超过了显存访问速度,transformer中的大部分计算操作的瓶颈是显存访问。对于显存受限的操作,IO感知是非常重要的,因为显存读写占用了大部分的运行时间。

在计算科学中,给定硬件,找出一个操作的性能瓶颈是非常重要的。一个操作的性能瓶颈有两类:计算受限(math-bound)和内存受限(memory-bound)。接下来从一些基本概念入手,来讨论计算受限和内存受限。

计算带宽(math bandwidth) 指的是处理器每秒钟可以执行的数学计算次数,单位通常是OPS(即operations/second)。如果用浮点数进行计算,单位是FLOPS。例如,A100-40GB SXM的计算带宽为312TFLOPS。

注意区分FLOPS和FLOPs。
FLOPS:全大写,floating point operations per second。每秒钟执行的浮点数操作次数,理解为运算速度,是衡量硬件性能的指标。
FLOPs:s为小写,floating point operation。表示浮点数运算次数,理解为计算量,衡量模型或算法的复杂度。

内存带宽 (memory bandwidth) 指的是处理器每秒钟从内存中读取的数据量,单位是bytes/second。例如,A100-40GB SXM的内存带宽为1555GB/s。

从数学上分析, 浮点数计算次数为 , 内存访问量为 , 计算带宽为 , 内存带宽为 。访问内存花费的时间为 , 计算花费的时间为 。由于可以“边计算上一个, 边读/写下一个", 访问内存和计算的时间可以重叠, 此时总的运行时间为

时, 硬件性能受限于计算带宽, 为计算受限 (math-bound)。反之, 当 $T_{m a t h}<t_{m e="" m}$="" 时,="" 硬件性能受限于内存带宽,="" 为内存受限="" (memory-bound)。<="" p="">

在数学上, 对于计算受限, 下面的不等式成立:

对于内存受限,下面的不等式成立:

通常将 ops/bytes  = byte   定义为算术强度(arithmetic intensity)。给定硬件的情况下, 对于一个模型或算法, 若算术强度  > math  mem  , 则为计算受限; 反之为内存受限。

假如用float16进行计算, A100-40GB SXM的 flops /bytes 。若一个操作的算术强度 types , 此时的性能受限于计算带宽; 反之, 性能受限于内存带宽。

如何计算矩阵乘法的计算量?
对于 , 计算 需要 次乘法操作和 次加法操作, 计算量为 flops。对于 , 计算 需要的计算量为 flops。如何计算一个矩阵的内存大小?对于 , 矩阵共有 个元素。若用float16或bfloat16进行存储, 每个元素需要 2 个 bytes。则矩阵A需要的内存大小为 bytes。

对于注意力中的矩阵乘法, , 计算 。其中, 为每个注意力头的维度。

从下表中可以看到, 即使对于计算量较大的矩阵乘法来说, 当 时, 性能主要是受限于内存带宽的。对于LLaMA和BLOOM模型, 注意力头维度 , 最大序列长度 , 矩阵乘法是计算受限的。

N d ops/bytes 受限类型
256 64 85 <201, memory-bound
2048 64 120 <201, memory-bound
4096 64 124 <201, memory-bound
256 128 171 <201, memory-bound
2048 128 228 >201, math-bound
4096 128 241 >201, math-bound
  • math-bound:性能受限于计算带宽。比如:大矩阵乘法、通道数很大的卷积运算
  • memory-bound:性能受限于内存带宽。逐点运算的操作大多是内存受限的,比如:激活函数、dropout、mask;另外reduction操作也是内存受限的,比如:softmax,batch normalization和layer normalization

对于self-attention块,除了大矩阵乘法是计算受限的,其他操作(计算softmax,dropout,mask)都是内存受限的。尽管近似注意力方法将计算复杂度降低为序列长度的线性或亚线性,由于忽略了内存访问开销,这些近似注意力方法并没有有效减少运行时间(wall-clock time)。Flash Attention则是IO感知的,通过减少内存访问,来计算精确注意力,从而减少运行时间,实现计算加速。

2.2 GPU内存分级

如下图左所示,GPU的内存由多个不同大小和不同读写速度的内存组成。内存越小,读写速度越快。对于A100-40GB来说,内存分级图如下所示。SRAM内存分布在108个流式多处理器上,每个处理器的大小为192K。合计为高带宽内存HBM(High Bandwidth Memory),也就是我们常说的显存,大小为40GB。SRAM的读写速度为19TB/s,而HBM的读写速度只有1.5TB/s,不到SRAM的1/10。上面讲到计算注意力的主要瓶颈是显存访问,因此减少对HBM的读写次数,有效利用更高速的SRAM来进行计算是非常重要的。

2.3 运行模式

GPU有大量的线程来执行某个操作,称为kernel。GPU执行操作的典型方式分为三步:(1)每个kernel将输入数据从低速的HBM中加载到高速的SRAM中;(2)在SRAM中,进行计算;(3)计算完毕后,将计算结果从SRAM中写入到HBM中。

2.4 kernel融合

对于性能受限于内存带宽的操作,进行加速的常用方式就是kernel融合kernel融合的基本思想是:避免反复执行“从HBM中读取输入数据,执行计算,将计算结果写入到HBM中”,将多个操作融合成一个操作,减少读写HBM的次数

需要注意的是,模型训练通常会影响到算子融合的效果,因为为了后向传递计算梯度,通常需要将某些中间结果写入到HBM中。

2.5 safe softmax

对于向量 , 原生 softmax的计算过程如下:

在实际硬件中, 浮点数表示的范围是有限的。对于float32和bfloat16来说, 当 时, 就会变成inf, 发生数据上溢的问题。为了避免发生数值溢出的问题, 保证数值稳定性, 计算时通常会“减去最大值", 称为“safe softmax"。 现在所有的深度学习框架中都采用了"safe softmax"这种计算方式。

在训练语言模型时,通常会采用交叉熵损失函数。交叉樀损失函数等价于先执行log_softmax函数, 再计算负对数似然函数。 在计算log_softmax时, 同样会执行“减去最大值”, 这不仅可以避免数值溢出, 提高数值稳定性; 还可以加快计算速度。

3. 前向传递

3.1 Standard Attention

transformer中注意力机制的计算过程为:

其中, , 其中 是序列长度, 是每个注意力头的维度。输出可以记为 。上面的式子可以拆解为:

在标准注意力实现中, 都要写回到HBM中, 占用了 的内存。通常 , 例如, 对于GPT2, ; 对于GPT3, 。注意力矩阵 需要的内存 远大于 所需要的内存 self-attention 中, 大部分操作都是内存受限的逐点运算, 例如, 对 的mask操作, 的softmax操作, 对 的 dropout操作。这些逐点操作的性能是受限于内存带宽的, 会减慢运行时间。

下图展示了标准注意力的实现过程。标准注意力实现存在两个问题:(1)显存占用多。 实例化了 完整的注意力矩阵 , 导致了 的内存要求。(2) HBM读写次数多,  减慢 了运行时间(wall- clock time)。

3.2 Memory-efficient Attention

在注意力计算过程中, 节省显存的主要挑战是softmax与 的列是耦合的。我们的方法是单独计算softmax的归一化因子, 来实现解耦。这种方法在《Online normalizer calculation for softmax》[4], 《Self-attention Does Not Need Memory》[5]中已经使用过。这种方法避免了实例化完整的注意力矩阵 , 不再需要 的显存占用。然而HBM访问次数仍然是 的, 因此运行时间并没有减少。

为了简化分析, 忽略计算softmax时“减去最大值”的步骤。记 的第 列为 的第 列为 ,有 。定义softmax的归一化因子为:

的第 个列向量, 则输出 的第 个列向量 为:

在计算得到归一化因子 后, 就可以通过反复累加 来得到 。节省内存 (memoryefficient)的注意力机制, 改变了计算顺序, 从而避免了实例化完整的注意力矩阵 , 达到了 节省显存的效果。相比于标准注意力机制, 节省显存的注意力机制将显存复杂度从 降低到了 《Self-attention Does Not Need Memory》[4]将这种方法称为“lazy softmax"。然而这种方法HBM访问次数仍然是 的, 因此运行时间并没有减少。

3.3 Flash Attention

在标准注意力实现中,注意力的性能主要受限于内存带宽,是内存受限的。频繁地从HBM中读写N\times NN\times N 的矩阵是影响性能的主要瓶颈。稀疏近似和低秩近似等近似注意力方法虽然减少了计算量FLOPs,但对于内存受限的操作,运行时间的瓶颈是从HBM中读写数据的耗时,减少计算量并不能有效地减少运行时间(wall-clock time)。针对内存受限的标准注意力,Flash Attention是IO感知的,目标是避免频繁地从HBM中读写数据

3.3.1 tiling,分块计算

从GPU显存分级来看,SRAM的读写速度比HBM高一个数量级,但内存大小要小很多。通过kernel融合的方式,将多个操作融合为一个操作,利用高速的SRAM进行计算,可以减少读写HBM的次数,从而有效减少内存受限操作的运行时间。但SRAM的内存大小有限,不可能一次性计算完整的注意力,因此必须进行分块计算,使得分块计算需要的内存不超过SRAM的大小。

为什么要进行分块计算呢? 内存受限 --> 减少HBM读写次数 --> kernel融合 --> 满足SRAM的内存大小 --> 分块计算。因此分块大小block_size不能太大,否则会导致OOM。

分块计算的难点是什么呢? 注意力机制的计算过程是“矩阵乘法 --> scale --> mask --> softmax --> dropout --> 矩阵乘法”,矩阵乘法和逐点操作(scale,mask,dropout)的分块计算是容易实现的,难点在于softmax的分块计算。由于计算softmax的归一化因子(分母)时,需要获取到完整的输入数据,进行分块计算的难度比较大。论文中也是重点对softmax的分块计算进行了阐述。

tiling的主要思想是分块计算注意力。分块计算的难点在于softmax的分块计算, softmax与 的列是耦合的, 通过引入了两个额外的统计量 来进行解耦, 实现了分块计算。为了保证数值稳定性, 对于 ,执行“减去最大值”的safe softmax的计算过程如下:

对于两个向量 , 解耦拼接向量 的softmax计算:

通过保持两个额外的统计量 , 可以实现softmax的分块计算。需要注意的是, 可以利用GPU多线程同时并行计算多个block的softmax。为了充分利用硬件性能, 多个block的计算不是串行 (sequential) 的, 而是并行的。

下面通过例子说明如何分块计算softmax。对向量 计算softmax, 分成两块 进行计算。计算block 1:

计算block 2:

合并得到完整的softmax结果:

在忽略mask和dropout的情况下,简化分析, Flash Attention算法的前向计算过程如下所示。从下图可以看到, 该算法在 的维度上做外循环, 在 的维度上做内循环。而在triton的代码实现中, 则采用了在 的维度上做外循环, 在 的维度上做内循环。

3.3.2 重计算

上文讲到,模型训练会影响kernel融合的效果,为了后向传递计算梯度,前向计算时通常需要将某些中间结果写回到HBM中,这会产生额外的HBM读写次数,减慢运行时间。因此,Flash Attention没有为后向传递保存很大的中间结果矩阵。

在标准注意力实现中, 后向传递计算 的梯度时, 需要用到 的中间矩阵 , 但这两个矩阵并没有保存下来。这里的技巧是重计算, 保存了两个统计量 , 后向传递时在高速的SRAM上快速地重新计算Attention, 通过分块的方式重新计算注意力矩阵 。相比于标准注意力中, 从HBM中读取很大的中间注意力矩阵的方法, 重计算的方法要快得多。

总的来说, Flash Attention通过调整注意力的计算顺序, 引入两个额外的统计量进行分块计算, 避免了实例化完整的 的注意力矩阵 , 将显存复杂度从 降低到了 。另外, 对于内存受限的标准注意力, Flash Attention通过kernel融合和分块计算, 大量减少了HBM 访问次数, 尽管由于后向传递中的重计算增加了额外的计算量FLOPs, 减少了运行时间, 计算速度更快(GPT2的7.6倍)。

3.3.3 kernel融合

为了简化分析, 上文介绍注意力时忽略了 mask和dropout操作。下面详细介绍Flash Attention前向传递的细节。给定输入 , 计算得到注意力输出

其中, 是softmax的缩放因子, 典型的比如 。MASK操作将输入中的某些元素置为 , 计算softmax后就变成了0, 其他元素保持不变; causal-Im结构和prefix-Im结构的主要差别就是 MASK矩阵不同。 逐点作用在 的每个元素上, 以 的概率将该元素置为 0 , 以 的概率将元素置为

tiling分块计算使得我们可以用一个CUDA kernel来执行注意力的所有操作。从HBM中加载输入数据,在SRAM中执行所有的计算操作(矩阵乘法,mask,softmax,dropout,矩阵乘法),再将计算结果写回到HBM中。通过kernel融合将多个操作融合为一个操作,避免了反复地从HBM中读写数据。kernel融合如下图所示,图片来源于https://www.bilibili.com/video/BV1Zz4y1q7FX/。

考虑mask和dropout操作,完整Flash Attention算法的前向计算过程如下所示:

4. 后向传递

在标准注意力实现中, 后向传递计算 的梯度时, 需要用到 的中间矩阵 。Flash Attention没有保存这两个矩阵, 而是保存了两个统计量 , 在后向传递时进行重计算。

在反向传递过程中, 需要计算损失函数 的梯度。在给定 的情况下, 计算梯度 。其中, 分别表示

梯度 是容易计算的。由 , 基于矩阵求导算法和链式法则, 得到矩阵形式的梯度 。在元素形式上, 有:

之前已经计算好 , 就可以通过反复累加的方式计算得到 。梯度 的计算是略微复杂的。首先要计算 。由 , 得到矩阵形式的梯度 。在元素形式上, 有:

。基于 的雅各比矩阵为 。可以得到:

其中 ○表示逐点相乘。

可以定义:

将该定义代回到上式中, 可以得到:

因此,梯度 可以表示为以下形式:

在计算得到 后, 可以计算 。有前向计算公式 , 可以得到:

与前向计算类似, 在计算得到 后, 就可以通过反复累加的方式计算得到 。避免了实例化矩阵 , 节省了显存, 后向传递的显存复杂度为

5. 性能分析

5.1 计算量与显存占用

Flash Attention的计算量FLOPs主要来源于矩阵乘法。

从算法1来看, 在内循环中, 第9行, 计算 , 其中, 。计算量为 , 即 FLOPs。

在第12行, 计算 , 其中 , 计算量为 , 即 FLOPs。共执行了 次内循环。总的计算量为:

关于额外的显存占用, 需要 的显存来保存统计量

5.2 IO复杂度

先分析标准注意力实现的HBM访问次数

第一步, 从HBM中读取 , 计算 , 将计算结果 写入到 HBM中。HBM的访问次数为

第二步: 从HBM中读取 , 计算 , 将计算结果 写入到HBM中。HBM的访问次数为

第三步: 从HBM中读取 以及 , 计算 , 将计算结果 写入到HBM中。HBM的访问次数为

总的来说, 标准注意力实现的HBM访问次数为

基于算法1,分析flash attention的HBM访问次数

从第6行来看, 中的每个block都只加载了一次, HBM的访问次数为

从第8行来看, 对于每个 的block, 完整的遍历了 , 从HBM中读取了 , 向HBM中写入了 。外循环共有 次。因此, HBM访问次数为

总的来说, flash attention的HBM访问次数为

【声明】内容源于网络
0
0
极市平台
为计算机视觉开发者提供全流程算法开发训练平台,以及大咖技术分享、社区交流、竞赛实践等丰富的内容与服务。
内容 8155
粉丝 0
极市平台 为计算机视觉开发者提供全流程算法开发训练平台,以及大咖技术分享、社区交流、竞赛实践等丰富的内容与服务。
总阅读8.7k
粉丝0
内容8.2k