极市导读
文章总结了降低Transformer模型复杂度的方法,包括Softmax Attention的计算复杂度、稀疏Attention方法等。 >>加入极市CV技术交流群,走在计算机视觉的最前沿
Transformer最重要的特性是Global Interaction,也就是说对于任意两个位置的token(不论它们离的有多远),它们之间都能直接进行信息交互。这个特性解决了传统序列建模中长依赖的问题。
但Transformer也有一个典型问题:它的计算复杂度和空间复杂度均为 , 其中 为序列长度。
因此实际应用中很难将Transformer应用到长序列任务上,如包数万个token的论文阅读、书籍阅读等任务。
解决Transformer计算复杂度的方法多种多样。本文介绍其中最主流、最常见的一些方法。
Note:
-
为简化,本文不单独讨论multi-head的情况。大多数方法都可以平移到到multi-head中。 -
本文主要讨论Transformer的Decoder。通常Encoder和Decoder的唯一区别是Encoder中当前token可以attend到左边和右边的其它token,而Decoder中当前token只能attend到左边token。所以本文介绍的这些方法都可以轻易地扩展到Encoder中。
1. Transformer的计算复杂度
首先来详细说明为什么Transformer的计算复杂度是 。将Transformer中标准的Attention称为Softmax Attention。令 为长度为 的序列, 其维度为 , 。 可看作Softmax Attention的输入。
Softmax Attention首先使用线性变换将输入 变换为Query、Key和Value:
(1)
(2)
(3)
其中 和 都是待训练的参数矩阵; 是 和 的维度; 是 的维度。由此可得 的shape分别为:
(4)
(5)
(6)
在常见的Transformer中, 通常 。因此为了简化符号, 我们假设后文中 , 并且只用符号 (Dimension)。
有了Q、K、V, Softmax Attention(SA)的计算如下:
(7)
容易看到,Softmax Attention的计算主要包含两次矩阵乘法操作。
首先回忆一下矩阵乘法的计算复杂度。 对于矩阵 和 , 它们的矩阵乘法共需要 次乘法运算。可以拿国内线性代数教材使用最多的计算方法来理解:为了计算这两个矩阵的乘积, 需要拿矩阵 的每一行去与矩阵 的每一列做点积。因此总共需要 次点积。每次点积包含 次乘法和 次加法。考虑到加法复杂度远小于乘法, 所以总的计算复杂度就是 。
这个 可以使用两种方法理解:
-
第一种理解方法, 因为加法复杂度远小于乘法, 所以忽略加法, 那么 计算复杂度中的base operator指的是乘法操作。 -
第二种理解方法, 因为 与 的量级一致, 所以 计算复杂度中的base operator 指的是乘加操作 (乘法和加法) 。
回到Transformer的复杂度问题上,前面提到Softmax Attention的计算主要包含两次矩阵乘法操作。
第一次矩阵乘法是 , 结合上文关于矩阵乘法复杂度的结论和这两个矩阵的大小(公式 (4)和公式(5)),可知 的复杂度为 。
第二次矩阵乘法是 sof tmax 的结果与 的乘积。sof tmax 输出的矩阵大小为 , 矩阵 的大小为 (公式(6), 前文假设了 ), 所以这一次矩阵乘法的复杂度为 。
因为这两次矩阵乘法是顺序执行的, 所以总的复杂度为它们各自复杂度之和。因为这两个复杂度相等, 相加只是引入了一个常数项, 所以可以忽略, 因此Softmax Attention总的复杂度就为
当我们只关心复杂度与序列长度 之间的关系时, 可以忽略 并将其写为 。
这就是通常说的Transformer计算复杂度随序列长度呈二次方增长的由来。容易看到,Transformer的空间复杂随序列长度也呈二次方增长,即空间复杂度也为 。
这一节最后,我们用一幅简单的图来说明Softmax Attention中参与每个token的Attention Score计算的其它token的位置(只考虑Decoder)。该图主要是为了与后文的一些其它复杂方法作对比。
这幅图按如下方法理解:行和列都表示位置;蓝色表示当前token,绿色表示参与当前token计算的其它token的位置。
例如,图中有12行,可以看作该示例中序列长度为12。以第二行为例,它表示对于第二个位置的token(蓝色位置,当前token),只有第一个位置的token会参与它Attention Score的计算。这其实就是Transformer中Decoder采用的方式:只能看当前token左边的token。
为了简化表述,后文会使用如下方式来表述:第二行中,第二个token只能attend到第一个token。
同理,在第三行中,第三个token可以attend到第一个和第二个token。
以此类推。
同时,也会采用被动表述。例如,在第二行中,第一个token被attended到第二个token。此时,第一个token也可以被称为attended token。
2. Sparse Attention
再看一次图1中的Softmax Attention,容易看到对于每一个token,它都会attend到它前面的所有token。所以通常说Softmax Attention是密集的(dense)。
与密集相对的就是稀疏(Sparse)了。Sparse Attention的主要思路是减少每个token需要attend的token数量。
比如,Softmax Attention对于每个token都要attend它之前的所有token。那么为了减少计算量,能不能只去attend之前的部分token?
2.1 Factorized Self-Attention (Sparse Transformer)
Paper:Generating Long Sequences with Sparse Transformers (2019)
Key Contribution:提出了两种稀疏Attention方法:Strided Attention和Fixed Attention。这二者均可将Transformer的 复杂度降低至 。
Factorized Self-Attention的一个基础假设是:在Softmax Attention中,真正为目标token提供信息的attended token非常少。
换言之,该假设意味着:对于Softmax Attention,在经softmax得到的Attention Weights中,其中大部分的值都趋于0,只有少数值明显大于0。因此Attention Weight比较稀疏。
论文作者将Transformer用到了图像自回归任务中来表明他们假设的合理性,如图2.1.1所示(图2.1.1不容易懂,看后文解释)。
解释一下图2.1.1。作者们用了128层的Transformer在CIFRA-10上做自回归训练。自回归训练是逐行逐像素来做的。
以图a)中左上方的红色汽车图为例,图中黑色区域(下方)是mask。模型下一步需要去预测mask中的第一个点。所谓第一个点,就是逐行看,看到的一个mask黑色点。图中白色区域是Attention Weights。可以看到,有效的Attention Weights几乎全部集中在当前待预测点周围。所以此时的Attention Weights很像卷积的局部性。同时它也很稀疏,因为Attention Weights在较远的位置几乎全为0。
图2中a、b、c、d是来自不同网络层的Attention Weights。可以看到,虽然Attention Weights表现出的空间规律有所差异,但它们总体上都很稀疏:只有极少部分的位置被有效attend(Attention Weights明显大于0,即图中白色区域)。
基于这种稀疏性,作者们提出了两种Attention方法。
注:针对这篇paper的Attention方法,本文不列具体公式。这是因为,这些方法其实都非常简单,但公式反而繁琐、不直观。
第一种方法称为Strided Attention。它又由两种Attention机制构成,我们把它们分别记为SA1和SA2(原文没有这种命名法,这里只是为了指代方便):
-
SA1: 每个token只能Attend它左边相邻的L个token。 -
SA2:每个token只能Attend它左边部分token,这些attened token用如下方法选出:从自己开始往左边数,每隔L就会有一个token可以attend(参见图2.1.3,比较直观)。
为便于理解,请参见图2.1.2和图2.1.3,我们假设L=3。
图2.1.2中的SA1很容易理解,每一个当前token(每一行的蓝色区域)只能attend到它左边的L个token,图中L=3。图2.1.3中的SA2稍微复杂一点,从自己开始往左边数,每隔L就会有一个token可以attend。比如图2.1.3中最后一行,从当前token(蓝色区域)开始往左边数,相隔L个空格(3个空格)处遇到第一个绿色方块可以attend(最后一行,第8列),然后再往左数L个(3个)空格,遇到第二个绿色方块可以attend(最后一行,第4列),以此类推。
Strided Attention的SA1方法和SA2方法的本质是在选择哪些token可以attend。
然后我们来看这两种Attention方法怎么用在Transformer结构中。有三种方法:
-
交替使用。 在第1个Transformer Block中使用SA1,然后在第2个Transformer Block中使用SA2,然后在第3个Transformer Block中又使用SA1,在第4个Transformer Block中又使用SA2,以此类推。这种方法能work的原因是:虽然SA1只能看左边的L个相邻位置,但可以认为在SA1中,每个token聚合了它左边L个token的信息。因此在SA2,虽然它是跳着L个位置看的,但整体感受野等价于整个序列(因为每个attended token聚合了其左边L个token的信息)。 -
联合使用。将SA1选择的attended token和SA2选择的attended token合在一起使用。这个方法很简单,就是在计算Attention时,首先用SA1去选择一些token,再用SA2去选择一些token,然后计算Attention时只使用选择出的token参与计算即可。 -
多头使用。类似Transformer采用的多头机制,这里每个头可以使用SA1、SA2或Transformer中的Softmax Attention。
然后来看 的选择。只要将 的值设为 , 那么容易看到整个Strided Attention的计算复杂度就是 。虽然这个做法很不自然, 但是它确实能实现 的复杂度。
至此,我们介绍完了Strided Attention。
作者们提出的第二种Attention称为Fixed Attention。Fixed Attention也有两种机制,将它们分别称为FA1和FA2。为了便于理解,需要把这两种机制画到一个图里,如图2.1.4所示。
先看FA2,如图中橙色区域。橙色区域的位置是固定的,即从左往右数,每隔L个位置,选中一个token。
理解了FA2,FA1的选择方式就会容易理解了。对于每个当前token(蓝色),往它左边遍历(绿色),直到遇到第一个FA2选中的token(橙色)。
Fixed Attention的使用方法和上文介绍的Strided Attention的三种方法一致(交替使用、联合使用、多头使用),不再赘述。
作者们的结论:Strided Attention适用于图像、音频;Fixed Attention适用于文本。
理由如下:Strided Attention在attended token的位置上做了强假设:哪些位置的token应该被attened,与当前token位置强相关。作者们认为这种适合图像、音频这类数据。而在文本上这类假设不成立。所以在Fixed Attention中,哪些位置的token应该被attened,与当前token位置无关。
讲的再简单点,图像、音频的局部信息很重要;而文本全局信息更重要。
总结:paper对新手不友好,简单的事情用了公式来解释,非常繁琐。希望本文能比原文容易理解一点。
2.2 Blockwise Self-Attention
Paper:Blockwise Self-Attention for Long Document Understanding (2019)
Key Contribution:通过分块来降低Softmax Attention的计算复杂度,方法简单,且实验效果较好。
前文提到了Transformer的时间复杂度和空间复杂度都为 。Blockwise Self-Attention这篇Paper对空间复杂度做了更细致的分析。
一个模型的Memory Usage主要来自三部分:Model Memory、Optimizer Memory、Activation Memory。按照Transformer模型通常使用的Adam类优化器来看,Optimizer Memory是Model Memory的三倍。这是因为Optimizer Memory需要为每个参数存储梯度、first momentum和second momentum。
Model Memory和Optimizer Memory可以直接计算出来。比如对于Model Memory,可以直接通过模型大小与参数类型(如FP16、FP32、INT8)来推算出精确值。同理,Optimizer Memory也可以精确计算出。而Activation Memory则与具体实现相关。所以在Paper中,作者们用PyTorch的内存分析工具来看训练时总的内存开销,然后减去Model Memory和Optimizer Memory,以此来估算Activation Memory。
作者们以BERT-base为例,分析了Model Memory、Optimizer Memory、Activation Memory三者的占比,其中Activation Memory独占87.6%,属于内存开销最大的部分。画一幅图来总结上面提到的内容(注意图中的memory usage的比例是针对BERT-base而言的):
我们说的空间复杂度 主要指的就是Activation Memory这一部分。因为Model Memory和Optimizer Memory是线性复杂度 。
Blockwise Self-Attention的核心思想非常简单:将一个长度为N的序列,平均分成n个短序列。当原始序列长度N无法被n除尽时,对原始序列进行padding,使它能被除尽。举一个例子来说明Blockwise Self-Attention的计算过程。
假设序列长度 , 每个token的维度为 。在Transformer中, Q、K、V三个矩阵的大小都为 。在Blockwise Self-Attention中, 假设分块数 , 那么每个分块中的序列长度为 。所以输入序列 可以划分为 个子序列: 、 , 它们的大小都为 。同理可以把 ( 同理) 划分成 个子矩阵: 、 , 它们的大小也都为 。在计算Self-Attention时, 每个 会去选择一个 和 来计算: (2.2.1)
在只有一个Attention头的情况下, 选择 和 的方法是:shifting one position。很简单, 选择 和 选择 和 选择 和 。换言之, 始终选下一个 和 ; 当 是最后一个block时, 选择 和 。这个过程可以用取余数的符号写出来, 但看着太繁琐, 所以文字描述了。
多头Attention情况下稍微麻烦一点。我们记序列 为单头Attention情况下每个 对应的 和 的编号 : (2.2.2)
仍以上面的示例为例, 在单头情况下, 的值为: (2.2.3)
它表示, 对应的 的值是2, 对应的 的值是 对应的 的值是1。
在多头情况下, 第 个头的 定义如下: (2.2.4)
例如, 按照上述示例, 第一个头的 为:(2.2.5)
第二个头的 为: (2.2.6)
因为分块数 , 所以需要取余数(注意下标从1开始, 所以余 0 时替换为 即可),得到最终的结果: (2.2.7)
过程其实很简单, 只是写出来稍微麻烦一点。
最后来分析复杂度。由本文第一部分分析Transformer复杂度的结论可知, 公式(2.2.1)中的复杂度为 。因为对每一个分块, 都需要用公式 (2.2.1) 进行计算, 所以总复杂度为: (2.2.8)
这既是计算复杂度,也是空间复杂度。
在原文中, 通常选为2。注意, 在大 计法中一般会忽略掉常数项。所以在这种意义下, Blockwise Self-Attention的复杂度仍为 。
但是大 计法的主要目的是理论分析, 并不为实际工程优化。所以即使在大 意义下复杂度没有变, 但它实际计算量仍然减少了。没有改变的 仍然意味着, Blockwise Self-Attention不能扩展到太大的 上, 这就是大 计法的作用。
具体来看, 当 时, RoBERTa的训练时间由原来的9.7天减少至7.5天。
总结:相比于Sparse Transformer中的Factorized Self-Attention,Blockwise Self-Attention更简单,且从效果上来看,优于Factorized Self-Attention。
2.3 Longformer
paper:Longformer: The Long-Document Transformer (2020)
Key Contribution:设计了多种不同的Local Attention和Global Attention方法。
首先重新看一下Factorized Self-Attention (2.1小节)中的两种Attention方法:Strided Attention和Fixed Attention。在Strided Attention中,又有两种Attention机制,在前文中我们把它们分别称为SA1和SA2(参考图2.1.2和2.1.3)。SA1的作用是Local Interaction,而SA2的作用是Global Interation。类似的,在Fixed Attention中(参考图2.1.4),FA1的作用是Local Interaction,而FA2的作用是Global Interation。
在Factorized Self-Attention中,它主要依靠两类Attention的组合使用来实现长距离依赖,例如SA1+SA2(或FA1+FA2)。
Longformer的核心idea和Factorized Self-Attention很像,只是Longformer中的部分Attention只有Local Interaction,没有Global Interaction。
Longformer一共提出了三种Attention,分别是Sliding Window based Attention(SW-Attention)、Dilated Sliding Window based Attention(DSW-Attention)和Global Attention(G-Attention)。下面分别介绍。
先看Sliding Window based Attention(SW-Attention),它其实和Strided Attention中的SA1完全一样。为了方便大家查看,重新把Strided Attention的SA1图copy一份到此处。
SW-Attention只Attend它左边的L个token。在SW-Attention中,L被称为“窗口大小”,而在Strided Attention中,L被称为“步长(Stride)”,它们本质一样。
实际上我们可以在Transformer中只使用SW-Attention来构建具有Global Interaction的网络。其方法很简单,只需要堆叠多个SW-Attention网络层即可,就如同CNN增大感受野的方式。假设窗口大小为K,一个M层的SW-Attention结构中最顶层的“感受野”大小为KM,如图2.3.2所示。
图中绿色方块表示当前token;蓝色线表示信息流; 每一层是上一层的输入;L假设为2。在第一层中,第一个token和第二个token的信息会 流入第二层中的第三个token。 而在第二层中,第二个token和第三个token的信息会流入下一层中的第四个token,以此类推。在最顶层(第五层),虽然当前token的信息只来自上一层的第四个token和第五个token,但从信息流的角度来看,它也隐含包含第一层中第一个和第二个token的信息。
可以看到,通过堆叠SW-Attention,Transformer也可以像CNN一样增加感受野。但是很容易想到,这种非常“间接”的方法不会有太好效果,就像CNN对长依赖建模的能力比较差一样。
再来看Dilated Sliding Window based Attention(DSW-Attention),它其实和Strided Attention中的SA2完全一样。为了方便大家查看,重新把Strided Attention的SA2图copy一份到此处。
DSW-Attention是“空洞”版的SW-Attention,就像空洞卷积和卷积之间的关系。简单来说,被attended的token不再像SW-Attention中是连续排列的,而是按等间距排列(间距称为“空洞率”,在图2.3.3中为3)。
与SW-Attention类似,通过堆叠多个DSW-Attention也能增大网络的感受野,从而实现Global Interaction。
最后再来看Global Attention(G-Attention)。G-Attention是SW-Attention的改进版,它的主要改动是:在SW-Attention基础上,增加了部分固定位置,使得这些位置的token需要 1)attend到其它所有token;2)被其它位置tokenattend到。如图2.3.4所示。
图中绿色token是SW-Attention会attend到的token。橙色token是在G-Attention中额外选中的token。以第五行的当前token为例(橙色),因为它是被额外选中的token,所以它会attend它左边的所有token。图中用黄色标出了相对于SW-Attention之外的额外被attended的token。此外,其它所有token也需要attend到第五个token,参见图中最后四行中的靠左黄色列。
图中第7行类似,大家可以自行对照图脑补一下这个过程。
在G-Attention中,哪些位置会被额外选中与具体下游任务相关。例如,在分类任务中,[CLS] token会被额外选中(Longformer一文中以RoBERTa为基础,将其中的Attention改为本文提到的Attention中的一种或多种);在问答任务中,所有问题的token都会被额外选中。
此外,G-Attention中有两份不同的QKV,一份用于计算由SW-Attention选中的token(图2.3.4中的绿色token),另一份用于计算由G-Attention额外选中的token(图2.3.4中的黄色token)。
上述提到的三种Attention的复杂度都为 , 因为哪些token会被attend与序列长度 无关。
2.4 Local attention and Memory-compressed attention
Paper: Generating wikipedia by summarizing long sequences (2018)
Key Contribution: 提出了Local Attention和Memory-compressed attention。Local Attention的计算复杂度随序列长度增长呈线性增长;Memory-compressed attention可以将计算复杂度减少固定常数倍(超参控制)。
2.4.1 Local Attention
前文中的2.3节也有一个Local Attention,但与此处的Local Attention方法不同。
此处Local Attention的核心思想是使用一个固定的分块大小n对输入序列进行分块,并限制self-attention的计算只能在各个分块内单独进行,如图2.4.1所示。
在图2.4.1中,每个位置的token只能attend到与它同颜色的其它token。例如图中第五行(红色标注行),它表示在Decoder结构中,对于输入序列中的第5个token的attention模式:第五行的灰色区域表示mask,这些mask表示Decoder结构中不能看到当前token之后的信息;前五个token根据颜色进行分块,每个token只能attend到同分块(同颜色)中的其它token,所以对于当前token而言(第五个token),它只能attend到第四个token和它自己(绿色部分)。
作为对比,标准的self-attention的模式图如下:
标准的Decoder结构中,只有一个限制:所有token都不能attend到当前序列之后的token。
Local Attention与2.2节介绍的Blockwise Self-Attention比较类似,其核心思想都是对输入序列进行分块。Local Attention与Blockwise Self-Attention唯一的区别是:Local Attention将Self-attention的计算限制在组内;而Blockwise Self-Attention将Self-attention的计算限制在组间。
例如,考虑图2.4.1中的最后一行,在最简单的情况下,Blockwise Self-Attention的attention模式为:每个分块的token只能attend到下一个分块(蓝色token只能attend到绿色token;绿色token只能attend到橙色token;橙色token只能attend到黄色token;黄色token只能attend到蓝色token)。
下面分析一下Local Attention的复杂度。Local Attention通常选择一个固定长度的分块大小n(例如 )。假设总的序列长度为 , 那么分块数量为 。每一个分块的复杂度为 个分块的总复杂度为 。因为 为常数项, 所以Local Attention的复杂度随序列长度 呈线性增长 。
但在2.3节中, 曾分析到Blockwise Self-Attention的复杂度是 。为何两个如此相似的方法复杂度却有显著差异?
在Blockwise Self-Attention中, 的含义不是分块大小, 而是分块数量。所以每个分块的大小就为 。那么每个分块的attention计算复杂度就是 。又因一共有 个分块, 所以总复杂度是 。
这里之所以把这两个复杂度拿出来对比,是想说明:小心对待复杂性分析中的常量。不同视角可能会导致不同的分析结果。
复杂度唯一能体现的仅仅是:计算量与变量之间的关系。
在上面的例子中,我们关心的变量是序列长度N,所以直接忽略了常数项n。但如果我们要比较这两个复杂度所对应的计算量时,常数项不能轻易忽略。
2.4.2 Memory-compressed Attention
在通常的基于Transformer的模型中,我们

