大数跨境
0
0

降低Transformer复杂度O(N^2)的方法汇总

降低Transformer复杂度O(N^2)的方法汇总 极市平台
2023-12-02
2
↑ 点击蓝字 关注极市平台
作者丨Civ@知乎(已授权)
来源丨https://zhuanlan.zhihu.com/p/634406691
编辑丨极市平台

极市导读

 

文章总结了降低Transformer模型复杂度的方法,包括Softmax Attention的计算复杂度、稀疏Attention方法等。 >>加入极市CV技术交流群,走在计算机视觉的最前沿

Transformer最重要的特性是Global Interaction,也就是说对于任意两个位置的token(不论它们离的有多远),它们之间都能直接进行信息交互。这个特性解决了传统序列建模中长依赖的问题。

但Transformer也有一个典型问题:它的计算复杂度和空间复杂度均为 , 其中 为序列长度。

因此实际应用中很难将Transformer应用到长序列任务上,如包数万个token的论文阅读、书籍阅读等任务。

解决Transformer计算复杂度的方法多种多样。本文介绍其中最主流、最常见的一些方法。

Note:

  • 为简化,本文不单独讨论multi-head的情况。大多数方法都可以平移到到multi-head中。
  • 本文主要讨论Transformer的Decoder。通常Encoder和Decoder的唯一区别是Encoder中当前token可以attend到左边和右边的其它token,而Decoder中当前token只能attend到左边token。所以本文介绍的这些方法都可以轻易地扩展到Encoder中。

1. Transformer的计算复杂度

首先来详细说明为什么Transformer的计算复杂度是 。将Transformer中标准的Attention称为Softmax Attention。令 为长度为 的序列, 其维度为 , 可看作Softmax Attention的输入。

Softmax Attention首先使用线性变换将输入 变换为Query、Key和Value:

(1)

(2)

(3)

其中 都是待训练的参数矩阵; 的维度; 的维度。由此可得 的shape分别为:

(4)

(5)

(6)

在常见的Transformer中, 通常 。因此为了简化符号, 我们假设后文中 , 并且只用符号 (Dimension)。

有了Q、K、V, Softmax Attention(SA)的计算如下:

(7)

容易看到,Softmax Attention的计算主要包含两次矩阵乘法操作。

首先回忆一下矩阵乘法的计算复杂度。 对于矩阵 , 它们的矩阵乘法共需要 次乘法运算。可以拿国内线性代数教材使用最多的计算方法来理解:为了计算这两个矩阵的乘积, 需要拿矩阵 的每一行去与矩阵 的每一列做点积。因此总共需要 次点积。每次点积包含 次乘法和 次加法。考虑到加法复杂度远小于乘法, 所以总的计算复杂度就是

这个 可以使用两种方法理解:

  • 第一种理解方法, 因为加法复杂度远小于乘法, 所以忽略加法, 那么 计算复杂度中的base operator指的是乘法操作。
  • 第二种理解方法, 因为 的量级一致, 所以 计算复杂度中的base operator 指的是乘加操作 (乘法和加法) 。

回到Transformer的复杂度问题上,前面提到Softmax Attention的计算主要包含两次矩阵乘法操作。

第一次矩阵乘法是 , 结合上文关于矩阵乘法复杂度的结论和这两个矩阵的大小(公式 (4)和公式(5)),可知 的复杂度为

第二次矩阵乘法是 sof tmax 的结果与 的乘积。sof tmax 输出的矩阵大小为 , 矩阵 的大小为 (公式(6), 前文假设了 ), 所以这一次矩阵乘法的复杂度为

因为这两次矩阵乘法是顺序执行的, 所以总的复杂度为它们各自复杂度之和。因为这两个复杂度相等, 相加只是引入了一个常数项, 所以可以忽略, 因此Softmax Attention总的复杂度就为

当我们只关心复杂度与序列长度 之间的关系时, 可以忽略 并将其写为

这就是通常说的Transformer计算复杂度随序列长度呈二次方增长的由来。容易看到,Transformer的空间复杂随序列长度也呈二次方增长,即空间复杂度也为

这一节最后,我们用一幅简单的图来说明Softmax Attention中参与每个token的Attention Score计算的其它token的位置(只考虑Decoder)。该图主要是为了与后文的一些其它复杂方法作对比。

图1 Softmax Attention中参与每个token的Attention Score计算的其它token的位置

这幅图按如下方法理解:行和列都表示位置;蓝色表示当前token,绿色表示参与当前token计算的其它token的位置。

例如,图中有12行,可以看作该示例中序列长度为12。以第二行为例,它表示对于第二个位置的token(蓝色位置,当前token),只有第一个位置的token会参与它Attention Score的计算。这其实就是Transformer中Decoder采用的方式:只能看当前token左边的token

为了简化表述,后文会使用如下方式来表述:第二行中,第二个token只能attend到第一个token。

同理,在第三行中,第三个token可以attend到第一个和第二个token。

以此类推。

同时,也会采用被动表述。例如,在第二行中,第一个token被attended到第二个token。此时,第一个token也可以被称为attended token。

2. Sparse Attention

再看一次图1中的Softmax Attention,容易看到对于每一个token,它都会attend到它前面的所有token。所以通常说Softmax Attention是密集的(dense)。

与密集相对的就是稀疏(Sparse)了。Sparse Attention的主要思路是减少每个token需要attend的token数量

比如,Softmax Attention对于每个token都要attend它之前的所有token。那么为了减少计算量,能不能只去attend之前的部分token?

2.1 Factorized Self-Attention (Sparse Transformer)

Paper:Generating Long Sequences with Sparse Transformers (2019)

Key Contribution:提出了两种稀疏Attention方法:Strided Attention和Fixed Attention。这二者均可将Transformer的 复杂度降低至

Factorized Self-Attention的一个基础假设是:在Softmax Attention中,真正为目标token提供信息的attended token非常少

换言之,该假设意味着:对于Softmax Attention,在经softmax得到的Attention Weights中,其中大部分的值都趋于0,只有少数值明显大于0。因此Attention Weight比较稀疏。

论文作者将Transformer用到了图像自回归任务中来表明他们假设的合理性,如图2.1.1所示(图2.1.1不容易懂,看后文解释)。

图2.1.1 Softmax Attention中Weight Vector稀疏性示意图

解释一下图2.1.1。作者们用了128层的Transformer在CIFRA-10上做自回归训练。自回归训练是逐行逐像素来做的。

以图a)中左上方的红色汽车图为例,图中黑色区域(下方)是mask。模型下一步需要去预测mask中的第一个点。所谓第一个点,就是逐行看,看到的一个mask黑色点。图中白色区域是Attention Weights。可以看到,有效的Attention Weights几乎全部集中在当前待预测点周围。所以此时的Attention Weights很像卷积的局部性。同时它也很稀疏,因为Attention Weights在较远的位置几乎全为0。

图2中a、b、c、d是来自不同网络层的Attention Weights。可以看到,虽然Attention Weights表现出的空间规律有所差异,但它们总体上都很稀疏:只有极少部分的位置被有效attend(Attention Weights明显大于0,即图中白色区域)

基于这种稀疏性,作者们提出了两种Attention方法。

注:针对这篇paper的Attention方法,本文不列具体公式。这是因为,这些方法其实都非常简单,但公式反而繁琐、不直观。

第一种方法称为Strided Attention。它又由两种Attention机制构成,我们把它们分别记为SA1SA2(原文没有这种命名法,这里只是为了指代方便):

  • SA1: 每个token只能Attend它左边相邻的L个token。
  • SA2:每个token只能Attend它左边部分token,这些attened token用如下方法选出:从自己开始往左边数,每隔L就会有一个token可以attend(参见图2.1.3,比较直观)。

为便于理解,请参见图2.1.2和图2.1.3,我们假设L=3。

图2.1.2 Strided Attention的SA1。图中每个token只能attend到它左边相邻的L个token,图中L=3。
图2.1.3 Strided Attention的SA2,图中L=3。

图2.1.2中的SA1很容易理解,每一个当前token(每一行的蓝色区域)只能attend到它左边的L个token,图中L=3。图2.1.3中的SA2稍微复杂一点,从自己开始往左边数,每隔L就会有一个token可以attend。比如图2.1.3中最后一行,从当前token(蓝色区域)开始往左边数,相隔L个空格(3个空格)处遇到第一个绿色方块可以attend(最后一行,第8列),然后再往左数L个(3个)空格,遇到第二个绿色方块可以attend(最后一行,第4列),以此类推。

Strided Attention的SA1方法和SA2方法的本质是在选择哪些token可以attend。

然后我们来看这两种Attention方法怎么用在Transformer结构中。有三种方法:

  • 交替使用。 在第1个Transformer Block中使用SA1,然后在第2个Transformer Block中使用SA2,然后在第3个Transformer Block中又使用SA1,在第4个Transformer Block中又使用SA2,以此类推。这种方法能work的原因是:虽然SA1只能看左边的L个相邻位置,但可以认为在SA1中,每个token聚合了它左边L个token的信息。因此在SA2,虽然它是跳着L个位置看的,但整体感受野等价于整个序列(因为每个attended token聚合了其左边L个token的信息)。
  • 联合使用。将SA1选择的attended token和SA2选择的attended token合在一起使用。这个方法很简单,就是在计算Attention时,首先用SA1去选择一些token,再用SA2去选择一些token,然后计算Attention时只使用选择出的token参与计算即可。
  • 多头使用。类似Transformer采用的多头机制,这里每个头可以使用SA1、SA2或Transformer中的Softmax Attention。

然后来看 的选择。只要将 的值设为 , 那么容易看到整个Strided Attention的计算复杂度就是 。虽然这个做法很不自然, 但是它确实能实现 的复杂度。

至此,我们介绍完了Strided Attention。

作者们提出的第二种Attention称为Fixed Attention。Fixed Attention也有两种机制,将它们分别称为FA1FA2。为了便于理解,需要把这两种机制画到一个图里,如图2.1.4所示。

图2.1.4 Fixed Attention中的FA1(绿色)和FA2(橙色),L=3

先看FA2,如图中橙色区域。橙色区域的位置是固定的,即从左往右数,每隔L个位置,选中一个token。

理解了FA2,FA1的选择方式就会容易理解了。对于每个当前token(蓝色),往它左边遍历(绿色),直到遇到第一个FA2选中的token(橙色)。

Fixed Attention的使用方法和上文介绍的Strided Attention的三种方法一致(交替使用、联合使用、多头使用),不再赘述。

作者们的结论:Strided Attention适用于图像、音频;Fixed Attention适用于文本。

理由如下:Strided Attention在attended token的位置上做了强假设:哪些位置的token应该被attened,与当前token位置强相关。作者们认为这种适合图像、音频这类数据。而在文本上这类假设不成立。所以在Fixed Attention中,哪些位置的token应该被attened,与当前token位置无关

讲的再简单点,图像、音频的局部信息很重要;而文本全局信息更重要。

总结:paper对新手不友好,简单的事情用了公式来解释,非常繁琐。希望本文能比原文容易理解一点。

2.2 Blockwise Self-Attention

Paper:Blockwise Self-Attention for Long Document Understanding (2019)

Key Contribution:通过分块来降低Softmax Attention的计算复杂度,方法简单,且实验效果较好。

前文提到了Transformer的时间复杂度和空间复杂度都为 。Blockwise Self-Attention这篇Paper对空间复杂度做了更细致的分析。

一个模型的Memory Usage主要来自三部分:Model Memory、Optimizer Memory、Activation Memory。按照Transformer模型通常使用的Adam类优化器来看,Optimizer Memory是Model Memory的三倍。这是因为Optimizer Memory需要为每个参数存储梯度first momentumsecond momentum

Model Memory和Optimizer Memory可以直接计算出来。比如对于Model Memory,可以直接通过模型大小与参数类型(如FP16、FP32、INT8)来推算出精确值。同理,Optimizer Memory也可以精确计算出。而Activation Memory则与具体实现相关。所以在Paper中,作者们用PyTorch的内存分析工具来看训练时总的内存开销,然后减去Model Memory和Optimizer Memory,以此来估算Activation Memory。

作者们以BERT-base为例,分析了Model Memory、Optimizer Memory、Activation Memory三者的占比,其中Activation Memory独占87.6%,属于内存开销最大的部分。画一幅图来总结上面提到的内容(注意图中的memory usage的比例是针对BERT-base而言的):

图2.2.1 BERT-base中内存分布示意图

我们说的空间复杂度 主要指的就是Activation Memory这一部分。因为Model Memory和Optimizer Memory是线性复杂度

Blockwise Self-Attention的核心思想非常简单:将一个长度为N的序列,平均分成n个短序列。当原始序列长度N无法被n除尽时,对原始序列进行padding,使它能被除尽。举一个例子来说明Blockwise Self-Attention的计算过程。

假设序列长度 , 每个token的维度为 。在Transformer中, Q、K、V三个矩阵的大小都为 。在Blockwise Self-Attention中, 假设分块数 , 那么每个分块中的序列长度为 。所以输入序列 可以划分为 个子序列: , 它们的大小都为 。同理可以把 ( 同理) 划分成 个子矩阵: , 它们的大小也都为 。在计算Self-Attention时, 每个 会去选择一个 来计算: (2.2.1)

在只有一个Attention头的情况下, 选择 的方法是:shifting one position。很简单, 选择 选择 选择 。换言之, 始终选下一个 ; 当 是最后一个block时, 选择 。这个过程可以用取余数的符号写出来, 但看着太繁琐, 所以文字描述了。

多头Attention情况下稍微麻烦一点。我们记序列 为单头Attention情况下每个 对应的 的编号 : (2.2.2)

仍以上面的示例为例, 在单头情况下, 的值为: (2.2.3)

它表示, 对应的 的值是2, 对应的 的值是 对应的 的值是1。

在多头情况下, 第 个头的 定义如下: (2.2.4)

例如, 按照上述示例, 第一个头的 为:(2.2.5)

第二个头的 为: (2.2.6)

因为分块数 , 所以需要取余数(注意下标从1开始, 所以余 0 时替换为 即可),得到最终的结果: (2.2.7)

过程其实很简单, 只是写出来稍微麻烦一点。

最后来分析复杂度。由本文第一部分分析Transformer复杂度的结论可知, 公式(2.2.1)中的复杂度为 。因为对每一个分块, 都需要用公式 (2.2.1) 进行计算, 所以总复杂度为: (2.2.8)

这既是计算复杂度,也是空间复杂度。

在原文中, 通常选为2。注意, 在大 计法中一般会忽略掉常数项。所以在这种意义下, Blockwise Self-Attention的复杂度仍为

但是大 计法的主要目的是理论分析, 并不为实际工程优化。所以即使在大 意义下复杂度没有变, 但它实际计算量仍然减少了。没有改变的 仍然意味着, Blockwise Self-Attention不能扩展到太大的 上, 这就是大 计法的作用。

具体来看, 当 时, RoBERTa的训练时间由原来的9.7天减少至7.5天。

总结:相比于Sparse Transformer中的Factorized Self-Attention,Blockwise Self-Attention更简单,且从效果上来看,优于Factorized Self-Attention。

2.3 Longformer

paper:Longformer: The Long-Document Transformer (2020)

Key Contribution:设计了多种不同的Local Attention和Global Attention方法。

首先重新看一下Factorized Self-Attention (2.1小节)中的两种Attention方法:Strided Attention和Fixed Attention。在Strided Attention中,又有两种Attention机制,在前文中我们把它们分别称为SA1和SA2(参考图2.1.2和2.1.3)。SA1的作用是Local Interaction,而SA2的作用是Global Interation。类似的,在Fixed Attention中(参考图2.1.4),FA1的作用是Local Interaction,而FA2的作用是Global Interation。

在Factorized Self-Attention中,它主要依靠两类Attention的组合使用来实现长距离依赖,例如SA1+SA2(或FA1+FA2)。

Longformer的核心idea和Factorized Self-Attention很像,只是Longformer中的部分Attention只有Local Interaction,没有Global Interaction。

Longformer一共提出了三种Attention,分别是Sliding Window based Attention(SW-Attention)、Dilated Sliding Window based Attention(DSW-Attention)和Global Attention(G-Attention)。下面分别介绍。

先看Sliding Window based Attention(SW-Attention),它其实和Strided Attention中的SA1完全一样。为了方便大家查看,重新把Strided Attention的SA1图copy一份到此处。

图2.3.1 SW-Attention示意图,它和Strided Attention中的SA1完全一样。图中L=3

SW-Attention只Attend它左边的L个token。在SW-Attention中,L被称为“窗口大小”,而在Strided Attention中,L被称为“步长(Stride)”,它们本质一样。

实际上我们可以在Transformer中只使用SW-Attention来构建具有Global Interaction的网络。其方法很简单,只需要堆叠多个SW-Attention网络层即可,就如同CNN增大感受野的方式。假设窗口大小为K,一个M层的SW-Attention结构中最顶层的“感受野”大小为KM,如图2.3.2所示。

图2.3.2 基于SW-Attention构建的Transformer的Global Interaction示意图

图中绿色方块表示当前token;蓝色线表示信息流; 每一层是上一层的输入;L假设为2。在第一层中,第一个token和第二个token的信息会 流入第二层中的第三个token。 而在第二层中,第二个token和第三个token的信息会流入下一层中的第四个token,以此类推。在最顶层(第五层),虽然当前token的信息只来自上一层的第四个token和第五个token,但从信息流的角度来看,它也隐含包含第一层中第一个和第二个token的信息。

可以看到,通过堆叠SW-Attention,Transformer也可以像CNN一样增加感受野。但是很容易想到,这种非常“间接”的方法不会有太好效果,就像CNN对长依赖建模的能力比较差一样。

再来看Dilated Sliding Window based Attention(DSW-Attention),它其实和Strided Attention中的SA2完全一样。为了方便大家查看,重新把Strided Attention的SA2图copy一份到此处。

图2.3.3 DSW-Attention示意图,它和Strided Attention中的SA2完全一样。图中Dilation=3。

DSW-Attention是“空洞”版的SW-Attention,就像空洞卷积和卷积之间的关系。简单来说,被attended的token不再像SW-Attention中是连续排列的,而是按等间距排列(间距称为“空洞率”,在图2.3.3中为3)。

与SW-Attention类似,通过堆叠多个DSW-Attention也能增大网络的感受野,从而实现Global Interaction。

最后再来看Global Attention(G-Attention)。G-Attention是SW-Attention的改进版,它的主要改动是:在SW-Attention基础上,增加了部分固定位置,使得这些位置的token需要 1)attend到其它所有token;2)被其它位置tokenattend到。如图2.3.4所示。

图2.3.4 G-Attention示意图,L=3

图中绿色token是SW-Attention会attend到的token。橙色token是在G-Attention中额外选中的token。以第五行的当前token为例(橙色),因为它是被额外选中的token,所以它会attend它左边的所有token。图中用黄色标出了相对于SW-Attention之外的额外被attended的token。此外,其它所有token也需要attend到第五个token,参见图中最后四行中的靠左黄色列。

图中第7行类似,大家可以自行对照图脑补一下这个过程。

在G-Attention中,哪些位置会被额外选中与具体下游任务相关。例如,在分类任务中,[CLS] token会被额外选中(Longformer一文中以RoBERTa为基础,将其中的Attention改为本文提到的Attention中的一种或多种);在问答任务中,所有问题的token都会被额外选中。

此外,G-Attention中有两份不同的QKV,一份用于计算由SW-Attention选中的token(图2.3.4中的绿色token),另一份用于计算由G-Attention额外选中的token(图2.3.4中的黄色token)。

上述提到的三种Attention的复杂度都为 , 因为哪些token会被attend与序列长度 无关。

2.4 Local attention and Memory-compressed attention

Paper: Generating wikipedia by summarizing long sequences (2018)

Key Contribution: 提出了Local Attention和Memory-compressed attention。Local Attention的计算复杂度随序列长度增长呈线性增长;Memory-compressed attention可以将计算复杂度减少固定常数倍(超参控制)。

2.4.1 Local Attention

前文中的2.3节也有一个Local Attention,但与此处的Local Attention方法不同。

此处Local Attention的核心思想是使用一个固定的分块大小n对输入序列进行分块,并限制self-attention的计算只能在各个分块内单独进行,如图2.4.1所示。

图2.4.1 Local Attention的模式图。图中假设序列长度N=12,分块大小n=3。

在图2.4.1中,每个位置的token只能attend到与它同颜色的其它token。例如图中第五行(红色标注行),它表示在Decoder结构中,对于输入序列中的第5个token的attention模式:第五行的灰色区域表示mask,这些mask表示Decoder结构中不能看到当前token之后的信息;前五个token根据颜色进行分块,每个token只能attend到同分块(同颜色)中的其它token,所以对于当前token而言(第五个token),它只能attend到第四个token和它自己(绿色部分)。

作为对比,标准的self-attention的模式图如下:

图2.4.2 标准self-attention的模式图。图中假设序列长度N=12。

标准的Decoder结构中,只有一个限制:所有token都不能attend到当前序列之后的token。

Local Attention与2.2节介绍的Blockwise Self-Attention比较类似,其核心思想都是对输入序列进行分块。Local Attention与Blockwise Self-Attention唯一的区别是:Local Attention将Self-attention的计算限制在组内;而Blockwise Self-Attention将Self-attention的计算限制在组间。

例如,考虑图2.4.1中的最后一行,在最简单的情况下,Blockwise Self-Attention的attention模式为:每个分块的token只能attend到下一个分块(蓝色token只能attend到绿色token;绿色token只能attend到橙色token;橙色token只能attend到黄色token;黄色token只能attend到蓝色token)。

下面分析一下Local Attention的复杂度。Local Attention通常选择一个固定长度的分块大小n(例如 )。假设总的序列长度为 , 那么分块数量为 。每一个分块的复杂度为 个分块的总复杂度为 。因为 为常数项, 所以Local Attention的复杂度随序列长度 呈线性增长

但在2.3节中, 曾分析到Blockwise Self-Attention的复杂度是 。为何两个如此相似的方法复杂度却有显著差异?

在Blockwise Self-Attention中, 的含义不是分块大小, 而是分块数量。所以每个分块的大小就为 。那么每个分块的attention计算复杂度就是 。又因一共有 个分块, 所以总复杂度是

这里之所以把这两个复杂度拿出来对比,是想说明:小心对待复杂性分析中的常量。不同视角可能会导致不同的分析结果。

复杂度唯一能体现的仅仅是:计算量与变量之间的关系

在上面的例子中,我们关心的变量是序列长度N,所以直接忽略了常数项n。但如果我们要比较这两个复杂度所对应的计算量时,常数项不能轻易忽略。

2.4.2 Memory-compressed Attention

在通常的基于Transformer的模型中,我们

【声明】内容源于网络
0
0
极市平台
为计算机视觉开发者提供全流程算法开发训练平台,以及大咖技术分享、社区交流、竞赛实践等丰富的内容与服务。
内容 8155
粉丝 0
极市平台 为计算机视觉开发者提供全流程算法开发训练平台,以及大咖技术分享、社区交流、竞赛实践等丰富的内容与服务。
总阅读197
粉丝0
内容8.2k