极市导读
一篇搞懂FlashAttention,深入分析FlashAttention加速attention计算的全过程,含有大量详细公式推导~ >>加入极市CV技术交流群,走在计算机视觉的最前
1. 前言
GPT3、LLaMA、ChatGLM、BLOOM等大语言模型输入输出的最大序列长度只有2048或4096, 扩展到更长序列的难度在哪里呢? 本质原因是, transformer模型的计算复杂度和空间复杂度都是 的, 其中 为序列长度。具体地, 当输入批次大小为 , 序列长度为 时, 层 transformer模型的计算量为 , 中间激活的显存大小为 , 其中 为注意力头数。可以看到, transformer模型的计算量和储存复杂度随着序列长度 呈二次方增长。 这限制了大语言模型的最大序列长度 的大小。计算量和中间激活见《分析transformer模型的参数量、计算量、中间激活、KV cache》。
最近, GPT4将最大序列长度 扩大到了 , Claude更是将最大序列长度 扩大到了100K, 这 些工作一定采用了一些优化方法来降低原生transformer的复杂度。我们知道, 每个transformer层分为两部分:self-attention块和MLP块。上面计算量中的 项和中间激活中的 项都是self-attention块产生的, 与MLP块无关。FlashAttention[1]提出了一种加速计算, 节省显存, 10 感知的精确注意力, 可以有效地缓解上述问题。Meta推出的开源大模型LLaMA[2], 阿联酋推出的开源大模型Falcon[2]都使用了Flash Attention来加速计算和节省显存。目前, Flash Attention已经集成到了pytorch2.0中, 另外triton、xformer等开源框架也进行了整合实现。
从论文题目《Flashattention: Fast and memory-efficient exact attention with io-awareness》入手,简要总结Flash Attention的优点。
-
加快了计算(Fast)。Flash Attention并没有减少计算量FLOPs,而是从IO感知出发,减少了HBM访问次数,从而减少了计算时间。论文中用到了"wall-clock time"这个词,时钟时间不仅包含了GPU运行耗时,还包含了IO读写的阻塞时间。减少HBM访问次数,是通过tiling技术分块和算子融合来实现的。
-
节省了显存 (Memory-efficient) 。Flash Attention通过引入统计量, 改变注意力机制的计算 顺序, 避免了实例化注意力矩阵 , 将显存复杂度从 降低到了 。这 个技巧不是首创的, 在[4][5]中也有应用。
-
精确注意力(Exact Attention)。不同于稀疏注意力,Flash Attention只是分块计算,而不是近似计算,Flash Attention与原生注意力的结果是完全等价的。
2. 背景知识
2.1 计算受限与内存受限
transformer的核心组件self-attention块的计算复杂度和空间复杂度是序列长度 的二次方,已经有许多近似注意力的方法尝试减少attention的计算和内存要求。例如,稀疏近似和低秩近似的方法,将计算复杂度降低到了序列长度的线性或亚线性,但这些方法并没有得到广泛应用。因为这些方法过于关注FLOPs(浮点数计算次数)的减少,而忽略了IO读写的内存访问开销。在现代GPU中,计算速度已经远超过了显存访问速度,transformer中的大部分计算操作的瓶颈是显存访问。对于显存受限的操作,IO感知是非常重要的,因为显存读写占用了大部分的运行时间。
在计算科学中,给定硬件,找出一个操作的性能瓶颈是非常重要的。一个操作的性能瓶颈有两类:计算受限(math-bound)和内存受限(memory-bound)。接下来从一些基本概念入手,来讨论计算受限和内存受限。
计算带宽(math bandwidth) 指的是处理器每秒钟可以执行的数学计算次数,单位通常是OPS(即operations/second)。如果用浮点数进行计算,单位是FLOPS。例如,A100-40GB SXM的计算带宽为312TFLOPS。
注意区分FLOPS和FLOPs。
FLOPS:全大写,floating point operations per second。每秒钟执行的浮点数操作次数,理解为运算速度,是衡量硬件性能的指标。
FLOPs:s为小写,floating point operation。表示浮点数运算次数,理解为计算量,衡量模型或算法的复杂度。
内存带宽 (memory bandwidth) 指的是处理器每秒钟从内存中读取的数据量,单位是bytes/second。例如,A100-40GB SXM的内存带宽为1555GB/s。
从数学上分析, 浮点数计算次数为 , 内存访问量为 , 计算带宽为 , 内存带宽为 。访问内存花费的时间为 , 计算花费的时间为 。由于可以“边计算上一个, 边读/写下一个", 访问内存和计算的时间可以重叠, 此时总的运行时间为 。
当 时, 硬件性能受限于计算带宽, 为计算受限 (math-bound)。反之, 当 $T_{m a t h}<t_{m e="" m}$="" 时,="" 硬件性能受限于内存带宽,="" 为内存受限="" (memory-bound)。<="" p="">
在数学上, 对于计算受限, 下面的不等式成立:
对于内存受限,下面的不等式成立:

假如用float16进行计算, A100-40GB SXM的 flops /bytes 。若一个操作的算术强度 types , 此时的性能受限于计算带宽; 反之, 性能受限于内存带宽。
如何计算矩阵乘法的计算量?
对于 , 计算 需要 次乘法操作和 次加法操作, 计算量为 flops。对于 , 计算 需要的计算量为 flops。如何计算一个矩阵的内存大小?对于 , 矩阵共有 个元素。若用float16或bfloat16进行存储, 每个元素需要 2 个 bytes。则矩阵A需要的内存大小为 bytes。
对于注意力中的矩阵乘法, , 计算 。其中, 为每个注意力头的维度。
从下表中可以看到, 即使对于计算量较大的矩阵乘法来说, 当 时, 性能主要是受限于内存带宽的。对于LLaMA和BLOOM模型, 注意力头维度 , 最大序列长度 , 矩阵乘法是计算受限的。
| N | d | ops/bytes | 受限类型 |
|---|---|---|---|
| 256 | 64 | 85 | <201, memory-bound |
| 2048 | 64 | 120 | <201, memory-bound |
| 4096 | 64 | 124 | <201, memory-bound |
| 256 | 128 | 171 | <201, memory-bound |
| 2048 | 128 | 228 | >201, math-bound |
| 4096 | 128 | 241 | >201, math-bound |
-
math-bound:性能受限于计算带宽。比如:大矩阵乘法、通道数很大的卷积运算。 -
memory-bound:性能受限于内存带宽。逐点运算的操作大多是内存受限的,比如:激活函数、dropout、mask;另外reduction操作也是内存受限的,比如:softmax,batch normalization和layer normalization。
对于self-attention块,除了大矩阵乘法是计算受限的,其他操作(计算softmax,dropout,mask)都是内存受限的。尽管近似注意力方法将计算复杂度降低为序列长度的线性或亚线性,由于忽略了内存访问开销,这些近似注意力方法并没有有效减少运行时间(wall-clock time)。Flash Attention则是IO感知的,通过减少内存访问,来计算精确注意力,从而减少运行时间,实现计算加速。
2.2 GPU内存分级
如下图左所示,GPU的内存由多个不同大小和不同读写速度的内存组成。内存越小,读写速度越快。对于A100-40GB来说,内存分级图如下所示。SRAM内存分布在108个流式多处理器上,每个处理器的大小为192K。合计为 。高带宽内存HBM(High Bandwidth Memory),也就是我们常说的显存,大小为40GB。SRAM的读写速度为19TB/s,而HBM的读写速度只有1.5TB/s,不到SRAM的1/10。上面讲到计算注意力的主要瓶颈是显存访问,因此减少对HBM的读写次数,有效利用更高速的SRAM来进行计算是非常重要的。
2.3 运行模式
GPU有大量的线程来执行某个操作,称为kernel。GPU执行操作的典型方式分为三步:(1)每个kernel将输入数据从低速的HBM中加载到高速的SRAM中;(2)在SRAM中,进行计算;(3)计算完毕后,将计算结果从SRAM中写入到HBM中。
2.4 kernel融合
对于性能受限于内存带宽的操作,进行加速的常用方式就是kernel融合。kernel融合的基本思想是:避免反复执行“从HBM中读取输入数据,执行计算,将计算结果写入到HBM中”,将多个操作融合成一个操作,减少读写HBM的次数。
需要注意的是,模型训练通常会影响到算子融合的效果,因为为了后向传递计算梯度,通常需要将某些中间结果写入到HBM中。
2.5 safe softmax
对于向量 , 原生 softmax的计算过程如下:
在实际硬件中, 浮点数表示的范围是有限的。对于float32和bfloat16来说, 当 时, 就会变成inf, 发生数据上溢的问题。为了避免发生数值溢出的问题, 保证数值稳定性, 计算时通常会“减去最大值", 称为“safe softmax"。 现在所有的深度学习框架中都采用了"safe softmax"这种计算方式。
在训练语言模型时,通常会采用交叉熵损失函数。交叉樀损失函数等价于先执行log_softmax函数, 再计算负对数似然函数。 在计算log_softmax时, 同样会执行“减去最大值”, 这不仅可以避免数值溢出, 提高数值稳定性; 还可以加快计算速度。
3. 前向传递
3.1 Standard Attention
transformer中注意力机制的计算过程为:
其中, , 其中 是序列长度, 是每个注意力头的维度。输出可以记为 。上面的式子可以拆解为:
在标准注意力实现中, 都要写回到HBM中, 占用了 的内存。通常 , 例如, 对于GPT2, ; 对于GPT3, 。注意力矩阵 需要的内存 远大于 所需要的内存 。self-attention 中, 大部分操作都是内存受限的逐点运算, 例如, 对 的mask操作, 的softmax操作, 对 的 dropout操作。这些逐点操作的性能是受限于内存带宽的, 会减慢运行时间。
下图展示了标准注意力的实现过程。标准注意力实现存在两个问题:(1)显存占用多。 实例化了 完整的注意力矩阵 , 导致了 的内存要求。(2) HBM读写次数多, 减慢 了运行时间(wall- clock time)。
3.2 Memory-efficient Attention
在注意力计算过程中, 节省显存的主要挑战是softmax与 的列是耦合的。我们的方法是单独计算softmax的归一化因子, 来实现解耦。这种方法在《Online normalizer calculation for softmax》[4], 《Self-attention Does Not Need Memory》[5]中已经使用过。这种方法避免了实例化完整的注意力矩阵 , 不再需要 的显存占用。然而HBM访问次数仍然是 的, 因此运行时间并没有减少。
为了简化分析, 忽略计算softmax时“减去最大值”的步骤。记 的第 列为 的第 列为 ,有 。定义softmax的归一化因子为:
记 为 的第 个列向量, 则输出 的第 个列向量 为:
在计算得到归一化因子 后, 就可以通过反复累加 来得到 。节省内存 (memoryefficient)的注意力机制, 改变了计算顺序, 从而避免了实例化完整的注意力矩阵 , 达到了 节省显存的效果。相比于标准注意力机制, 节省显存的注意力机制将显存复杂度从 降低到了 。《Self-attention Does Not Need Memory》[4]将这种方法称为“lazy softmax"。然而这种方法HBM访问次数仍然是 的, 因此运行时间并没有减少。
3.3 Flash Attention
在标准注意力实现中,注意力的性能主要受限于内存带宽,是内存受限的。频繁地从HBM中读写N\times NN\times N 的矩阵是影响性能的主要瓶颈。稀疏近似和低秩近似等近似注意力方法虽然减少了计算量FLOPs,但对于内存受限的操作,运行时间的瓶颈是从HBM中读写数据的耗时,减少计算量并不能有效地减少运行时间(wall-clock time)。针对内存受限的标准注意力,Flash Attention是IO感知的,目标是避免频繁地从HBM中读写数据。
3.3.1 tiling,分块计算
从GPU显存分级来看,SRAM的读写速度比HBM高一个数量级,但内存大小要小很多。通过kernel融合的方式,将多个操作融合为一个操作,利用高速的SRAM进行计算,可以减少读写HBM的次数,从而有效减少内存受限操作的运行时间。但SRAM的内存大小有限,不可能一次性计算完整的注意力,因此必须进行分块计算,使得分块计算需要的内存不超过SRAM的大小。
为什么要进行分块计算呢? 内存受限 --> 减少HBM读写次数 --> kernel融合 --> 满足SRAM的内存大小 --> 分块计算。因此分块大小block_size不能太大,否则会导致OOM。
分块计算的难点是什么呢? 注意力机制的计算过程是“矩阵乘法 --> scale --> mask --> softmax --> dropout --> 矩阵乘法”,矩阵乘法和逐点操作(scale,mask,dropout)的分块计算是容易实现的,难点在于softmax的分块计算。由于计算softmax的归一化因子(分母)时,需要获取到完整的输入数据,进行分块计算的难度比较大。论文中也是重点对softmax的分块计算进行了阐述。
tiling的主要思想是分块计算注意力。分块计算的难点在于softmax的分块计算, softmax与 的列是耦合的, 通过引入了两个额外的统计量 来进行解耦, 实现了分块计算。为了保证数值稳定性, 对于 ,执行“减去最大值”的safe softmax的计算过程如下:
对于两个向量 , 解耦拼接向量 的softmax计算:
通过保持两个额外的统计量 , 可以实现softmax的分块计算。需要注意的是, 可以利用GPU多线程同时并行计算多个block的softmax。为了充分利用硬件性能, 多个block的计算不是串行 (sequential) 的, 而是并行的。
下面通过例子说明如何分块计算softmax。对向量 计算softmax, 分成两块 和 进行计算。计算block 1:
计算block 2:
合并得到完整的softmax结果:
在忽略mask和dropout的情况下,简化分析, Flash Attention算法的前向计算过程如下所示。从下图可以看到, 该算法在 的维度上做外循环, 在 的维度上做内循环。而在triton的代码实现中, 则采用了在 的维度上做外循环, 在 的维度上做内循环。
3.3.2 重计算
上文讲到,模型训练会影响kernel融合的效果,为了后向传递计算梯度,前向计算时通常需要将某些中间结果写回到HBM中,这会产生额外的HBM读写次数,减慢运行时间。因此,Flash Attention没有为后向传递保存很大的中间结果矩阵。
在标准注意力实现中, 后向传递计算 的梯度时, 需要用到 的中间矩阵 , 但这两个矩阵并没有保存下来。这里的技巧是重计算, 保存了两个统计量 , 后向传递时在高速的SRAM上快速地重新计算Attention, 通过分块的方式重新计算注意力矩阵 。相比于标准注意力中, 从HBM中读取很大的中间注意力矩阵的方法, 重计算的方法要快得多。
总的来说, Flash Attention通过调整注意力的计算顺序, 引入两个额外的统计量进行分块计算, 避免了实例化完整的 的注意力矩阵 , 将显存复杂度从 降低到了 。另外, 对于内存受限的标准注意力, Flash Attention通过kernel融合和分块计算, 大量减少了HBM 访问次数, 尽管由于后向传递中的重计算增加了额外的计算量FLOPs, 减少了运行时间, 计算速度更快(GPT2的7.6倍)。
3.3.3 kernel融合
为了简化分析, 上文介绍注意力时忽略了 mask和dropout操作。下面详细介绍Flash Attention前向传递的细节。给定输入 , 计算得到注意力输出 。
其中, 是softmax的缩放因子, 典型的比如 。MASK操作将输入中的某些元素置为 , 计算softmax后就变成了0, 其他元素保持不变; causal-Im结构和prefix-Im结构的主要差别就是 MASK矩阵不同。 逐点作用在 的每个元素上, 以 的概率将该元素置为 0 , 以 的概率将元素置为 。
tiling分块计算使得我们可以用一个CUDA kernel来执行注意力的所有操作。从HBM中加载输入数据,在SRAM中执行所有的计算操作(矩阵乘法,mask,softmax,dropout,矩阵乘法),再将计算结果写回到HBM中。通过kernel融合将多个操作融合为一个操作,避免了反复地从HBM中读写数据。kernel融合如下图所示,图片来源于https://www.bilibili.com/video/BV1Zz4y1q7FX/。
考虑mask和dropout操作,完整Flash Attention算法的前向计算过程如下所示:
4. 后向传递
在标准注意力实现中, 后向传递计算 的梯度时, 需要用到 的中间矩阵 。Flash Attention没有保存这两个矩阵, 而是保存了两个统计量 , 在后向传递时进行重计算。
在反向传递过程中, 需要计算损失函数 对 的梯度。在给定 的情况下, 计算梯度 。其中, 分别表示
梯度 是容易计算的。由 , 基于矩阵求导算法和链式法则, 得到矩阵形式的梯度 。在元素形式上, 有:
之前已经计算好 , 就可以通过反复累加的方式计算得到 。梯度 的计算是略微复杂的。首先要计算 。由 , 得到矩阵形式的梯度 。在元素形式上, 有:
有 。基于 的雅各比矩阵为 。可以得到:
其中 ○表示逐点相乘。
可以定义:
将该定义代回到上式中, 可以得到:
因此,梯度 可以表示为以下形式:
在计算得到 后, 可以计算 。有前向计算公式 , 可以得到:
与前向计算类似, 在计算得到 后, 就可以通过反复累加的方式计算得到 。避免了实例化矩阵 , 节省了显存, 后向传递的显存复杂度为 。
5. 性能分析
5.1 计算量与显存占用
Flash Attention的计算量FLOPs主要来源于矩阵乘法。
从算法1来看, 在内循环中, 第9行, 计算 , 其中, 。计算量为 , 即 FLOPs。
在第12行, 计算 , 其中 , 计算量为 , 即 FLOPs。共执行了 次内循环。总的计算量为:
关于额外的显存占用, 需要 的显存来保存统计量 。
5.2 IO复杂度
先分析标准注意力实现的HBM访问次数:
第一步, 从HBM中读取 , 计算 , 将计算结果 写入到 HBM中。HBM的访问次数为 。
第二步: 从HBM中读取 , 计算 , 将计算结果 写入到HBM中。HBM的访问次数为 。
第三步: 从HBM中读取 以及 , 计算 , 将计算结果 写入到HBM中。HBM的访问次数为 。
总的来说, 标准注意力实现的HBM访问次数为 。
基于算法1,分析flash attention的HBM访问次数:
从第6行来看, 中的每个block都只加载了一次, HBM的访问次数为 。
从第8行来看, 对于每个 的block, 完整的遍历了 , 从HBM中读取了 , 向HBM中写入了 。外循环共有 次。因此, HBM访问次数为 。
总的来说, flash attention的HBM访问次数为

