极市导读
全面从源码入手,重新来理解deformable attention >>加入极市CV技术交流群,走在计算机视觉的最前沿
大家好,今天来写写基于Transformer架构的端到端检测模型:Deformable Detr。
如果你也读过Deformable Detr的论文和代码,你可能会发现相比于原生detr,它确实不好理解。特别是这篇论文在写作时,抽象的东西太多,细节的部分较少(我看看谁是论文后半节各种公式的受害者)。而当你想从源码中一探更多细节时,你会发现除了代码本身也不太好读外,它和论文间还有些gap。
所以,本文在写作时,基本抛弃了论文,全面从源码入手,重新来理解deformable attention:
-
在本文第一章节,我们会快速回顾detr的基本思想,并讨论一个重要但鲜少被深入提及的问题:detr中的encoder和decoder到底学到了什么? -
在理解这一点的基础上,我们重新来端详deformable detr。来一步步捋顺它背后的巧思。
Deformable detr的核心之一是deformable attention,它借鉴了普通attention的框架,但实现起来却大相径庭。值得一提的是,虽然初次接触它时较难理解,但一旦明白它的运作流程,你就会发现它是比普通attention更加具象化的算法,因为你甚至能清楚想象它的每一步都学到了什么。
本文将根据源码重新绘制deformable detr的架构细节图,并假设读者已了解detr,全文目录如下:
一、快速回顾detr
1.1 detr整体架构
1.2 encoder和decoder到底学到了什么
二、Deformable detr
2.1 原生detr的问题
(1)训练时间长
(2)小物体检测效果不佳
2.2 解决训练时间长:deformable attention
(1)按方位寻找关键局部点
(2)单尺度特征图下的deformable attention
2.3 解决小物体检测问题:多尺度特征图
(1)生产多尺度特征图
(2)多尺度特征图上的坐标对齐问题
(3)多尺度特征图下的deformable attention
2.4 decoder上的deformable attention
三、精华代码放送
四、参考
一、快速回顾detr
1.1部分将快速回顾detr整体架构(会附上一些原始论文没有的,来自源码的细节)。1.2部分需要重点关注,它将有助于我们对deformable detr的具象化理解。
1.1 detr整体架构
DETR架构图如下(原图来自DETR论文,为了讲解方便,我手动添加了一些细节):
整体来说,detr遵循了transformer encoder-decoder架构,我们分别来看encoder和decoder在做一件什么事。
(1)Encoder
-
原始输入图片:我们假设原始图片尺寸为 (C, H, W)。 -
使用backbone做特征提取:通常使用resnet50作为backbone,它吃原始图片,吐出原始图片的特征图 -
对特征图添加二维位置编码:二维位置编码可以反应特征图中每一个像素点的行、列信息。detr除了提供三角函数式的二维位置编码外,也支持可学习位置编码。 -
一些CNN处理:在此基础上,再经过一些简单的卷积处理,就得到上图所绘的最终版特征图,其尺寸为 (hidden_dim, H_, W_) -
展平特征图,得到最终encoder的输入:如上图所示,展平后,encoder部分一共有 H_ * W_个尺寸为(1, hidden_dim)的向量。在接下来的描述中,我统一称呼这些向量为token
接下来,我们只要正常对这些token做multi-head self-attention的操作,就能得到encoder部分的输出了。细节略去不谈。
(2)Decoder
在传统CNN形式的目标检测中,一个物体可能会被检测出多个检测框,所以通常需要类似NMS这样的后置操作去消除冗余的检测框。而detr的目的则是实现end2end的目标检测:理想情况下,一个物体就出一个框,不需要任何后置操作。在这样的目标下,detr提出了object query的概念(见上图)。假设在现实世界中,一张图片的检测框最多只有n个。那么我们就把object query的数量设为n(detr的操作中设n=100)。然后让每个object query去找其唯一对应的真值检测框(bounding box)及框中物体的类别(class)。多出来的object query就代表“背景”。这里我们同样不再展开,有疑惑的朋友可以先去看detr的设计细节。以object query作为输入,剩下的过程就和传统transformer decoder差不多了,也不再赘述。
1.2 encoder和decoder到底学到了什么
了解detr的架构不难,但难的是理解它为何这么设计。比如有一个非常重要但鲜少被详细讨论的问题:detr的encoder和decoder到底看见了/学到了什么?为什么它们能让end2end的目标检测奏效?
理解这一点,对接下来讨论deformable detr的设计至关重要。 所以这里我们配合可视化的方法,来详细讨论一下这个问题。
(1)encoder学到了什么
从前文可知,encoder的每一个输入token,其实就是特征图中的某个像素点向量。
所以对于一个token来说,它在做self-attention的过程,就是寻找哪些token和它密切相关的过程。切换到图像的语境下,就是一个像素点,在寻找哪些像素点和它密切相关。
那怎么定义”密切相关”呢?当然是找到这个token对于其余每一个token的attention_weights(不难知道,对某个token来说它算出的attention_weights有H_*W_个)。然后我们对这一排attention_weights,将其reshape成(H_, W_)的形式,将(H_, W_)视作一张图,再根据其中每个值的大小去高量图中的每个像素点,这样我们不就能可视化出“每个像素点(token)都学到了些什么”吗?
按这个思路,我们在原始图像中取出一个像素点,计算出它在特征图中的像素位置,然后就能可视化它的学习情况了,例子如下:
这4个红点相关的token学到的信息如上图所示。以最右侧的红点为例,这个像素点在猫背上,经过无数轮学习后,它最终明白自己的关注重点应该是右边这只胖猪咪。
在有趣的可视化后,我们得到一个重要结论:在detr encoder中,每个token可粗理解为图中的某一个/块像素点,它遍历图中剩余的所有像素点(包括它自己),来学习自己应该特别关注哪些像素点。encoder的目的是掌握全局信息。
(2)decoder学到了什么
如果说encoder的每个输入token代表特征图的一个像素点的话,那么decoder的object query同样可以理解为特征图中的某一个像素点,只不过它是随机的。你可以将这些随机像素点理解成是检测框的质心,在训练的过程中,这些随机点不断去学习自己应该关注的部分,然后调整自己的位置。
我们用同样的方法,取出decoder和encoder做cross-attention部分的attention weights,这样我们就能可视化出某个object query到底在关注图片中的什么部分,可视化结果如下:
又是一个有趣的结论:在detr decoder中,每个object query,学习的是自己应该关注的检测框的范围。
二、Deformable Detr
2.1 原生detr的问题
在deformable detr的论文开头,直接点出了原生detr的两个重要问题:
(1)训练时间太长
(2)难以准确检测出小物体
单看论文,我们可能很难理解这两点,但有了第一章的基础,现在我们可以来详细阐释这两个问题了。
(1)训练时间太长
在原生detr中,特征图的一个像素点就是一个token,则我们一共有H_ * W_个token。而对这些token做attention计算的复杂度是
。这也意味着特征图越大,计算代价就越高。
但是,当你回顾前文所讲的“encoder部分到底学到了什么”,你一定会产生这样的疑惑:每一个token,真得有必要和其余所有token都做attention吗”。例如,下图中右侧胖猪咪背上的那个像素点,它其实只要和这只胖猪咪附近的像素点做attention,然后学到自己是胖猪咪背上的一个点就行了。它并不需要关注诸如沙发那样的像素点。
所以,对于一个token,它或许只用和附近的某些像素点做attention,就能提炼出相关信息。这样一来还能降低计算复杂度。
(2)难以检测出小物体
通过前面的分析,我们知道特征图不宜太大,这样会影响计算复杂度。
但是如果原始图片上一个目标物体太小,它在压缩后的特征图上可能就几乎找不到了,因此我们很难将其检测出来。所以一般情况下,小物体我们需要更大尺寸的特征图。
可是多大尺寸的特征图才行呢?这个问题我们也很难回答。所以,我们有一个异想天开的想法:让不同尺寸的特征图都参与训练,提高一个物体被检测出来的可能。
所以,针对detr的两个缺陷,deformable detr提出了两个改进方向:
-
减少训练时长:将全局attention转为局部attention -
提升小物体检测准确率:采用多尺度特征图
接下来我们将配合图例和源码解读,对这两个改进做详细说明。
2.2 解决训练时长问题:Deformable Attention
Deformable attention,直译过来为“可形变的注意力”,这是将全局attention转为局部attention的核心。我们先不着急来看它怎么算,我们先来看一个更直观的问题:该怎么才能让一个像素点(token)找到自己应该关注的那些局部像素点?
(1)按方位寻找关键局部点
理想情况下,对于图中红色像素点,它只需要关注那些绿色像素点,或许就能勾勒出一只胖猪咪的轮廓。
但是这些绿色点位于红色点的四面八方,红点很难一下找到它们。那现在我们不妨降低一点难度:我们以红点为中心,指引它去不同的方向上找到这些绿点。例如:
-
我们以红点为圆心,切分出8个方向。之后红点只需在每个方向上去找到它要关注的那些像素点即可。这样的找寻更有规律,而不是像之前一样无头无脑地一顿乱找。(为了表达简练,图中只给出两个方向上的示例) -
在每个方向上,红点都去找n_points个和它相关的绿点(在deformable detr中设n_points = 4)。这样在每个方向上,红点只需和n_points个绿点做attention,而不需要和全局像素点再做attention了。 -
在每个方向上,红点每次都去学习像素维度的偏移量(offsets)。这样红点即可通过自身像素坐标和偏移量计算出该方向上每个绿点的位置。
从上面的解释中,我们不难理解:在这样的设计下,红点不再和全局像素点做attention。而是尽量绕着上下左右的方向去找关键像素点做attention。而模型是不断学习的,每个方向上的关键像素点是不断在改变的,因此红点在每个方向上看到的视野一直在变化。故而我们管这样的attention机制叫“deformable attention”(可变形的自注意力)。它可类比于CNN中的可变形卷积,因为它的本质也是改变感受野。
你可能想问:那我们怎么让模型找到方向信息呢? 大家还记得multi-head attention中的head吗!在nlp中,每个head可能关注自然语言的不同部分(语义、时态等)。而在deformable attention中,每个head就可以用来学习不同方向的关键点信息! 是不是很有趣,这也是我文章开头所说,deformable attention其实是比普通attention来说更具象化的东西。
(2)单尺度下的deformable attention
到目前为止,你可能对deformable attention有了初步感知,但还比较难清晰地把它对到模型架构实现上。所以现在我们来具体解读下单尺度下的deformable attention是怎么实现的。
首先,类比于正常attention,我们有一个用于计算value的线形矩阵Wv,其尺寸为(hidden_dim, hidden_dim)。我们对所有的token都计算其value值。在下面的图例中,我们把value值reshape成(hidden_dim, H_, W_)的形式,是为了方便下文的理解,不影响大家对value的正常解读。
然后,我们来计算“单尺度的deformable attention”,单尺度意味着我们依然只用某个大小的特征图来作为encoder输入,例如大小为(hidden_dim, H_, W_)。整体过程如下:
如前文所说,我们使用backbone,提取原始图片的特征图。
-
特征图中的每一个像素,都可以被视为2.2(1)中所说的红点(在上图中依然也是用红点表示)。每个像素都有自己对应的像素坐标。对于每个像素,模型都要去学习它在各个方向上的偏移量,然后再根据偏移量计算出偏移点(图中绿点)。我们遵循源码中的表示,称这样用像素坐标表示的像素点为reference_points(参考点) -
特征图中的每一个像素,相当于encoder输入中的一个token,有其对应的token向量表示。 -
综上,特征图中的每个像素点都有两种表示:作为encoder输入的token向量,以及表示其像素坐标的reference_points。
【注:在代码实现时,reference_points并不是用绝对像素坐标表示的,而是用了一种类似于归一化的形式,这点我们放在后文说,此处不影响大家对其概念的理解】
我们的以其中一个token向量(像素点)为例,如前文所说,它被拆成8个head,分别去找8个方向上的绿点(如前文所说)。一个像素点在一个head上,只和它找到的这些绿点做attention。类比于一般的attention需要去和全局像素点做attention,这种操作一下将attention的计算复杂度从O(N^2)降至约O(N)。
好,我们知道计算attention离不开两个要素:attention score和value。则此时对于一个token向量(像素点)的某一个head(方向):
-
它先找到了自己需要关注的4个偏移点(绿点),更准确地说,它找到了这4个绿点的像素坐标。 -
有了像素坐标,我们就可以通过“插值计算”的方式,从这份特征图的value结果(图中蓝色方块),找出这4个绿点的value值。 -
对于attention score,deformable attention和普通attention的计算方式不同:普通attention是通过q,v值去计算attention score的;而deformable attention则是直接去学习attention score(图中紫色块)。同时通过softmax的方式,保证所有attention score的sum为1。 -
现在对于一个token向量(像素点),它在一个head(方向)上的value和attention score都有了。它就能计算attention score * value的结果了,这一点和普通attention一样。 -
最后将计算结果concat起来,就能得到这个向量经过deformable attention之后的输出结果了。
现在,大家是不是能更好体会到,虽然deformable attention乍看上去不好理解,但一旦理清了它的运作流程,就会发现它其实是比普通attention更加具象化的算法,因为你甚至能清楚理解你的每一个token在学什么,又是怎么学到的。
2.3 解决小物体检测问题:多尺度特征图
前面我们说过,对于小物体,若我们使用较小的特征图,它可能就被压缩成特征图上的一个像素点,因此很难被我们检测出来。所以小物体我们期望用更大尺度的特征图。但由于我们并不知道具体要使用多大的特征图才合适,所以我们有了一个异想天开的想法:把各种不同尺寸的特征图都试一次,然后把在每个特征图上的结果综合起来,这样不就能提升小物体被检测出来的概率了吗。
例如,对于图中的遥控器部分的某个像素点。我们在不同尺度的特征图上,都执行类似于2.2(2)中所描绘的单尺度deformable attention那样的操作,然后把结果通过某种方式综合起来,这样我们就不怕遥控器找不到了。
我们来更具象化地看一下这个过程。
(1)产出多尺度特征图
在deformable detr中,选择采用4种不同尺寸的特征图。上图刻画了四种不同尺寸的特征图产出的过程。具体来说:
-
我们从backbone(通常是resnet50)的不同layer中先产出其中3种尺寸的特征图 -
最后一种尺寸的特征图由backbone最后一层产出的特征图 + 一些卷积操作而来。 -
在图中你会发现我们对原始特征图还做了一层卷积操作,你可以理解成是一种线性变化,目的是把C维度的尺寸映射成和hidden_dim一致。
(2)多尺度特征图上的坐标对齐
我们知道,多尺度特征图的意义是,假设现在对于某张图片,你在它的某个尺寸的特征图上随机取一点。这一点在原图中代表的物体可能非常小(比如是遥控器的按钮),经过压缩后,它在特征图上可能就只剩这一点了。因此模型很难将它检测出来。
所以现在我们的想法是,如果这个点在这个尺寸的特征图上难以被检测,我们就把它放到别的尺寸的特征图上都试一下,也许在别的尺寸的特征图上,这个遥控器按钮就不再是一个点了,而是一小片区域,这样我们就更可能将其检测出来。
总结一下,在多尺度的特征图上,我们想做的事情是:
-
首先,对于一张图片,产出4种不同尺寸的特征图(如前文所说,4是模型的一个超参) -
对每个尺寸的特征图上的任意一点,我们都希望找到它在其余特征图中的位置(像素坐标)。这样,这一点不仅可以在本特征图中做deformable attention,还可以在别的尺寸的特征图中做deformable attention。 -
最后我们通过某种方式,将这一点在各个尺度的上的deformable attention结果综合起来,就能得到最后的结果。
第三点我们将在后文细说。本节我们重点来看第二点,即我们需要回答一个问题:对于某一尺寸特征图上的任一点,你怎么找到它对应在其它特征图中的像素坐标?
很容易知道,运用绝对像素的表达方式(例如位于第几行第几列个pixel)是无法满足这点的。所以我们自然而然想到:那就使用相对像素的表达方式,对像素点坐标做归一化,这样在不同的特征图中,相同的像素点坐标就能差不多表示原图的同一个位置了。
如果你的batch_size = 1,那么或许可以考虑这么做。但是当batch_size > 1时,问题就出现了:我们知道在一个batch内我们会做padding的操作,对于这些不同尺寸的特征图也是如此(被padding部分的像素拉平后,就是一个为0的token向量,和nlp中的padding是一样的)。对于一条数据,它的每个特征图被padding的比例是不一样的,这时我们就不能直接用归一化方法了,例子如下:
图例中给出了一条数据下的2种不同尺寸的特征图,我们以0.5像素单位为格子,划分这两张特征图。一般情况下它们被padding的比例不一样。图中红色部分表示原始特征图,其余部分表示被padding的地方。
对于左图中H方向为3.5的点,采用归一化方式后,它的H轴坐标为3.5/5.0 = 0.7。那么现在回到右侧更小的特征图中,归一化后值为0.7的坐标的绝对坐标为:0.7 * 3 = 2.1。
发现了吗:在左侧特征图中,明明还在非padding部分的点,在右侧就变成去padding部分了。这就造成了一定的偏差。
所以,更准确地做法是,我们依然还用归一化的形式来表示像素坐标,但这个归一化是相对于非padding部分(我称其为有效部分)。可能你觉得有点抽象,那么我们来看一个具体的例子(都以H轴为例,W轴同理可推)。
假设对于batch中的某条数据:
-
特征图1的高度为H1,有效高度为HE1,其上某个点绝对坐标为h1。则该点归一化坐标可以表示成h1/H1 -
特征图2的高度为H2,有效高度为HE2,那么特征图1中的这个点在特征图2中,对应的归一化坐标应该是多少?
【解】:
-
设特征图1中的这个点,在特征图2中的绝对坐标为h2。 -
我们希望用有效部分来做坐标归一化,即我们希望:h1/HE1 = h2/HE2。这个的意思是,如果一个点在特征图1中是在非padding部分,那么在特征图2中理应也在非padding部分。反之同理。 -
则我们有:h2 = (h1/HE1) * HE2 -
进一步我们有:(h2/H2) = (h1/HE1) * (HE2/H2)。而HE2/H2就是特征图2上的有效部分占比 -
h2/H2就是特征图1上的这个点,在特征图2上对应的归一化坐标。
如果看到这里你还是觉得很懵,你可以再回顾一下我们的目标:对于一条数据,它有4个特征图,我们取出其中某一特征图,我们在上面任取一点,我们希望能计算出这一点在其余特征图上的坐标表示。牢记这个目标,再回头看一遍上面的过程,你就能有更清楚的理解了。这其实就是源码中get_reference_points的实现方式,很多朋友初次看源码时可能理解不了它在做什么,这一块就是对源码的解读。
(3)多尺度特征图上的deformable attention
有了前面的介绍打底,现在我们可以直接来看多尺度特征图上的deformable attention是如何实现的了,细节如下图:
我们从下往上,来整理一下整个过程:
-
产出多尺度特征图。使用“backbone + 二维函数式位置编码(也可换成可学习的)+ 尺度层级位置编码”产出不同尺度的特征图。细节我们在前文已阐述过,这里不赘述。只额外提一点:这里会新增一个尺度层级位置编码,用于表示当前token位于第几个特征图上。这是一个可学习的位置编码。实现不难,这里不再细究,大家可看源码。
-
把所有特征图上的像素点展开成一排token向量。即图中所绘制的红、黄、绿、紫四类向量。它们排成一排,作为deformable attention的输入。
-
学习attention_weights。每一个token(像素点)的每一个head(方向)上都需要学习尺寸为(num_levels = 特征图数量, n_points = 偏移点)大小的注意力权重。这num_levels * n_points个注意力权重将一起做softmax,使得它们的和为1。
-
学习偏移量。
-
首先,对于一个token(像素点),我们能得到它在自己所属的特征图中的归一化像素表示 -
然后,使用2.3(2)节介绍过的坐标对齐方法,计算得到这个token在其余特征图中的归一化像素坐标表示。 -
接着,对于这个token(像素点),我们让它在每个尺度下都去学习4个偏移量(图中绿点)。这个我们在2.2(2)中介绍的单尺度特征图下的操作一致。 -
最后,我们再复习一下这样做的原因:一个token(像素点)在自己所属的特征图中,不一定能学到东西;所以我们把它放到不同尺寸的特征图中,让它都学一遍,然后汇总它学到的知识。 -
通过插值方法,取得偏移量的value值。
-
目前为止,对于某个token(像素点),我们得到了它在不同特征图上的不同偏移量结果(按照我们的超参设置,在token的每个head/方向上,我们一共有4*4 = 16个偏移量结果)。 -
因为这些偏移量来自不同的特征图,所以为了我们自然要去不同特征图的value中,通过插值的方法,把这些偏移量对应的value取出来。因此你会在上图中看见红、黄、绿、紫四个不同颜色的方块,它们表示不同特征图的value值(具体计算方法在2.2(2)中已经给出)。 -
计算每个head(方向)上的attention_weights *value值,得到这个head最终的输出结果。细节画在图中了,这里不再赘述。
-
concat各个head(方向)上的结果,得到这个token经过deformable attention后最终的输出。
-
图中其余特征图上的token(黄绿紫向量),也是按照这个流程计算。
到此为止,我们就deformable detr encoder部分的最重要的内容介绍完毕了。当然代码中涉及到一些细节。为了不干扰大家阅读,我把源码相关的内容都放在本文最后一章。
2.4 decoder上的deformable attention
现在我们来快速看下deformable detr decoder的部分,在这一块中:
Decoder self-attention部分,采用的是正常的attention,因为object query一般就设为100个,它们之间做attention不会消耗太多资源。
-
Decoder cross-attention部分,采用的是deformable attention。
你应该发现一个有意思的问题了:采用deformable attention,意味着我把输入token当作是特征图上的一个像素点,它在各个特征图上都能找到对应的坐标,有了这些信息,我才能继续做deformable attention。可是decoder的输入是我随机初始化的,它又没有对应的特征图,我该怎么给这些object query附上这些信息呢?
为了解答这个问题,我们再来回顾本文开头1.2(2)的部分:decoder学到了什么?还记得在这一部分中,我们如何根据可视化结果解释decoder中的object query吗?答案是,你可以将其视为在特征图上随机的一点,在模型学习的过程中,这个点在不断调整自己看见的检测框的范围。从这个意义上,一个object query可以理解成能代表它所找的检测框的质心(或者和检测框相关的别的点位),它将不断调整质心的位置(只是在原生detr中,不能显像看到这一点)。
但是在deformable attention中,这一点被显象化了。你可以发现,在对object query做坐标相关的初始化时,我们相当于在特征图上随机找一点来附值给它,然后让它能继续做deformable attention的相关操作。同时,我们学习了一个用于“找坐标”的参数矩阵,让object query在每次学习时,都能调整自己的位置。
可能有些朋友在阅读源码时,很难理解源码中相关部分的操作。结合前文的这些描述,现在大家应该能更好阅读decoder部分的代码了。在我看来decoder部分最难理解的就是这一点,所以剩余的细节,我就不展开了。
三、精华代码放送
为减少篇幅,下面为大家放送我觉得代码中最精华的部分:
-
Deformable attention的计算 -
不同特征图上的像素坐标对齐
其余的数据处理、位置编码等之类就不再详细说了,大家可以自行阅读。所有的代码都详细标出了输入输出的参数含义、维度、样例。基本每行代码都做了注释,尽量减少因为裁剪代码而造成的上下文信息缺失。
(1)deformbale attention的计算
这节中还有一个有趣的细节大家可以看看:那就是怎么通过参数初始化,来区分head所代表的8个方向,并且让之后学到的偏移点都大致在这8个方向周围分布。
class MSDeformAttn(nn.Module):
def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
"""
Multi-Scale Deformable Attention Module
:param d_model hidden dimension
:param n_levels number of feature levels (特征图的数量)
:param n_heads number of attention heads
:param n_points number of sampling points per attention head per feature level
每个特征图的每个attention上需要sample的points数量
"""
super().__init__()
if d_model % n_heads != 0:
raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
_d_per_head = d_model // n_heads
# you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
if not _is_power_of_2(_d_per_head):
warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
"which is more efficient in our CUDA implementation.")
self.im2col_step = 64
self.d_model = d_model
self.n_levels = n_levels
self.n_heads = n_heads
self.n_points = n_points
# =============================================================================
# 每个query在每个head每个特征图(n_levels)上都需要采样n_points个偏移点,每个点的像素坐标用(x,y)表示
# =============================================================================
self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
# =============================================================================
# 每个query用于计算注意力权重的参数矩阵
# =============================================================================
self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
# =============================================================================
# value的线性变化
# =============================================================================
self.value_proj = nn.Linear(d_model, d_model)
# =============================================================================
# 输出结果的线性变化
# =============================================================================
self.output_proj = nn.Linear(d_model, d_model)
self._reset_parameters()
def _reset_parameters(self):
# =============================================================================
# sampling_offsets的权重初始化为0
# =============================================================================
constant_(self.sampling_offsets.weight.data, 0.)
# =============================================================================
# thetas: 尺寸为(nheads, ),假设nheads = 8,则值为:
# tensor([0*(pi/4), 1*(pi/4), 2*(pi/4), ..., 7 * (pi/4)])
# 好似把一个圆切成了n_heads份,用于表示一个图的nheads个方位
# =============================================================================
thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
# =============================================================================
# grid_init: 尺寸为(nheads, 2),即每一个方位角的cos和sin值,例如:
# tensor([[ 1.0000e+00, 0.0000e+00],
# [ 7.0711e-01, 7.0711e-01],
# [-4.3711e-08, 1.0000e+00],
# [-7.0711e-01, 7.0711e-01],
# [-1.0000e+00, -8.7423e-08],
# [-7.0711e-01, -7.0711e-01],
# [ 1.1925e-08, -1.0000e+00],
# [ 7.0711e-01, -7.0711e-01]])
# =============================================================================
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
# =============================================================================
# 第一步:
# grid_init / grid_init.abs().max(-1, keepdim=True)[0]:计算8个head的坐标偏移,尺寸为torch.Size([n_heads, 2])
# 结果为:
# tensor([[ 1., 0.],
# [ 1., 1.],
# [0., 1.],
# [-1., 1.],
# [-1., 0.],
# [-1., -1.],
# [0., -1.],
# [1., -1]])
# 然后把这个数据广播给每个n_level的每个n_point
# 最后grid_init尺寸为:(nheads, n_levels, n_points, 2)
# 这意味着:在第一个head上,每个level上,每个偏移点的偏移量都是(1,0)
# 在第二个head上,每个level上,每个偏移点的偏移量都是(1,1),以此类推
# =============================================================================
grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1)
# =============================================================================
# 每个参考点的初始化偏移量肯定不能一样,所以这里第i个参考点的偏移量设为:
# (i,0), (i,i), (0,i)...(i,-i)
# grid_init尺寸依然是:(nheads, n_levels, n_points, 2)
# 现在意味着:在第一个head上,每个level上,第一个偏移点偏移量是(1,0), 第二个是(2,0),第三个是(3,0), 第四个是(4,0)
# 在第二个head上,每个level上,都一个偏移点偏移量是(1,1), 第二个是(2,2), 第三个是(3,3), 第四个是(4,4)
# =============================================================================
for i in range(self.n_points):
grid_init[:, :, i, :] *= i + 1
# =============================================================================
# 初始化sampling_offsets的bias,但其不参与训练。尺寸为(nheads * n_levels * n_points * 2,)
# =============================================================================
with torch.no_grad():
self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
# =============================================================================
# 其余参数的初始化
# =============================================================================
constant_(self.attention_weights.weight.data, 0.)
constant_(self.attention_weights.bias.data, 0.)
xavier_uniform_(self.value_proj.weight.data)
constant_(self.value_proj.bias.data, 0.)
xavier_uniform_(self.output_proj.weight.data)
constant_(self.output_proj.bias.data, 0.)
def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
"""
Args:
query:原始输入数据 + 位置编码后的结果,尺寸为(B, sum(所有特征图的token数量), 256)
sum(所有特征图的token数量)其实就是sum(H_*W_)
reference_poins:尺寸为(B, sum(所有特征图的token数量), level_num, 2)。表示对于 batch中的每一条数据的每个token,它在不同特征层上的坐标表示。
请一定参见get_reference_points函数相关注释
input_flatten: 原始输入数据,尺寸为(B, sum(所有特征图的token数量), 256)
input_spatial_shapes: tensor,其尺寸为(level_num,2)。 表示原始特征图的大小。
其中2表是Hi, Wi。例如:
tensor([[94, 86],
[47, 43],
[24, 22],
[12, 11]])
input_level_start_index: 尺寸为(level_num, )
表示每个level的起始token在整排token中的序号,例如:
tensor([0, 8084, 10105, 10633])
input_padding_mask: mask信息,(B, sum(所有特征图的token数量))
"""
# =============================================================================
# N:batch_size
# len_q: query数量,在encoder attention中等于len_in
# len_in: 所有特征图组成的token数量
# =============================================================================
N, Len_q, _ = query.shape
N, Len_in, _ = input_flatten.shape
# 声明所有特征图的像素数量 = token数量
assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in
# =============================================================================
# self.value_proj:线性层,可理解成是Wq, 尺寸为(d_model, d_model)
# value:v值,尺寸为(B, sum(所有特征图的token数量), 256)
# =============================================================================
value = self.value_proj(input_flatten)
# =============================================================================
# 对于V值,将padding的部分用0填充(那个token向量变成0向量)
# =============================================================================
if input_padding_mask is not None:
value = value.masked_fill(input_padding_mask[..., None], float(0))
# =============================================================================
# 将value向量按head拆开
# value:尺寸为:(B, sum(所有特征图的token数量), nheads, d_model//n_heads)
# =============================================================================
value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
# =============================================================================
# self.sampling_offsets:偏移点的权重,尺寸为(d_model, n_heads * n_levels * n_points * 2)
# 【对于一个token,求它在每个head的每个level的每个偏移point上
# 的坐标结果(x, y)】
# 由于sampling_offsets.weight.data被初始化为0,但sampling_offsets.bias.data却被初始化
# 为我们设定好的偏移量,所以第一次fwd时,这个sampling_offsets是我们设定好的初始化偏移量
# self.sampling_offsets(query) = (B, sum(所有特征图的token数量), d_model) *
# (d_model, n_heads * n_levels * n_points * 2)
# = (B, sum(所有特征图的token数量), n_heads * n_levels * n_points * 2)
# =============================================================================
sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
# =============================================================================
# self.attention_weights: 线性层,尺寸为(d_model, n_heads * n_levels * n_points),
# 初始化时weight和bias都被设为0
# 因此attention_weights第一次做fwd时全为0
# =============================================================================
attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
# =============================================================================
# attention_weights: 表示每一个token在每一个head的每一个level上,和它的n_points个偏移向量
# 的attention score。
# 初始化时这些attention score都是相同的
# =============================================================================
attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
# =============================================================================
# reference_points:尺寸为(B, sum(所有token数量), level_num, 2)
# N, Len_q, n_heads, n_levels, n_points, 2
# =============================================================================
if reference_points.shape[-1] == 2:
# ======================================================================
# offset_normalizer: 尺寸为(level_num, 2),表示每个特征图的原始大小,坐标表达为(W_, H_)
# ======================================================================
offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
# ======================================================================
# 先介绍下三个元素:
# reference_points[:, :, None, :, None, :]:
# 尺寸为 (B, sum(token数量), 1, n_levels, 1,2)
# sampling_offsets:
# 尺寸为(B, sum(token数量), n_heads, n_levels, n_points, 2)
# offset_normalizer[None,None,None,:,None,:]:
# 尺寸为(1, 1, 1,n_levels, 1,2)
# 再介绍下怎么操作的:
# (1)sampling_offsets / offset_normalizer[None, None, None, :, None, :]:
# 前者表示预测出来的偏移量(单位是像素绝对值)通过相除,把它变成像素归一化以后的维度
# (2) 加上reference_points:表示把该token对应的这个参考点做偏移,
# 得到其在各个level上的n_points个偏移结果,偏移结果同样是用归一化的像素坐标来表示
# sampling_locations:尺寸为(B, sum(tokens数量), nhead, n_levels, n_points, 2)
# ======================================================================
sampling_locations = reference_points[:, :, None, :, None, :] \
+ sampling_offsets / offset_normalizer[None, None, None, :, None, :]
elif reference_points.shape[-1] == 4:
sampling_locations = reference_points[:, :, None, :, None, :2] \
+ sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
else:
raise ValueError(
'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1]))
output = MSDeformAttnFunction.apply(
value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
output = self.output_proj(output)
return output
MSDeformAttnFunction.apply细节如下:
def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
"""
Args:
value: 尺寸为:(B, sum(所有特征图的token数量), nheads, d_model//n_heads)。对原始token做线性变化
后的value值
input_spatial_shapes: tensor,其尺寸为(level_num,2)。 表示原始特征图的大小。
其中2表是Hi, Wi。例如:
tensor([[94, 86],
[47, 43],
[24, 22],
[12, 11]])
sampling_locations:尺寸为(B, sum(tokens数量), nhead, n_levels, n_points, 2)。
每个token在每个head、level上的n_points个偏移点坐标(坐标是归一化的像素值),
每个坐标是按(w,h)表达的,注意不是(h,1)
attention_weights: 尺寸为(B, sum(tokens数量), nheads, n_levels, n_points)
每个token在每个head、level上对n_points个偏移点坐标的注意力权重
"""
# for debug and test only,
# need to use cuda version instead
N_, S_, M_, D_ = value.shape
_, Lq_, M_, L_, P_, _ = sampling_locations.shape
# ================================================================================
# 截取出每个特征图的value
# value_list: tuple[Tensor],value_list长度 = n_levels
# 每个tensor尺寸为(B, sum(某个特征图token数量), nheads, d_model//n_heads)
# ================================================================================
value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
# ================================================================================
# 原来我们归一化后,坐标是在H=(0,1), W=(0,1)这个假设下的,现在我们想让H = (-1, 1), W = (-1,1),
# 所以对所有的坐标也要做相应处理
# sampling_grids: 尺寸依然是(B, sum(tokens数量), nhead, n_levels, n_points, 2)
# ================================================================================
sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for lid_, (H_, W_) in enumerate(value_spatial_shapes):
# (B, H_*W_, nheads, head_dim) -> (B, H_*W_, nheads*head_dim) -> (B, n_heads*head_dim, H_*W_) -> (B*nhead, head_dim, H_, W_)
value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)
# (B, sum(所有token), nheads, n_poins, 2) -> (B, nheads, sum(所有token), n_points, 2) -> (B*nhead, sum(所有token), n_points, 2)
sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
#(B*nhead, head_dim, sum(所有token), n_points)
sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
mode='bilinear', padding_mode='zeros', align_corners=False)
sampling_value_list.append(sampling_value_l_)cc
# (B, sum(token数量), nheads, n_levels, n_points) -> (B*nheads, sum(token数量), n_levels, n_points) -> (B, nheads, 1, sum(token数量), n_levels * n_points)
attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)
# ================================================================================
# (1)torch.stack(sampling_value_list, dim=-2).flatten(-2):
# 尺寸为(B*n_heads, head_dim, sum(token数量), n_levels*n_points)
# (2)乘上attention_weights =
# (B*n_heads, head_dim, sum(token数量), n_levels*n_points) *
# (B*n_heads, 1, sum(token数量), n_levels*n_points)
# = (B*n_heads, head_dim, sum(token数量), n_levels*n_points)
# (3).sum(-1): 处理后尺寸为(B*n_heads, head_dim, sum(token数量))
# (4)最后把output处理成:(B, nheads*head_dim, sum(tokens))
# ================================================================================
output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
# ================================================================================
# 经过转置以后,最后的deformable attention的output变成(B, sum(tokens),nheads*head_dim)
# ================================================================================
return output.transpose(1, 2).contiguous()
(2)不同特征图上的像素坐标对齐
@staticmethod
def get_reference_points(spatial_shapes, valid_ratios, device):
"""
Args:
spatial_shapes: tensor,其尺寸为(level_num,2)。 表示原始特征图的大小。
其中2表是Hi, Wi。例如:
tensor([[94, 86],
[47, 43],
[24, 22],
[12, 11]])
valid_ratios: 尺寸为(B, level_num, 2),
用于表示batch中的每条数据在每个特征图上,分别沿着H和W方向的
有效比例(有效 = 非padding部分)
例如特征图如下:
1, 1, 1, 0
1, 1, 1, 0
0, 0, 0, 0
则该特征图在H方向上的有效比例 = 2/3 = 0.6
在W方向上的有效比例 = 3/4 = 0.75
"""
reference_points_list = []
for lvl, (H_, W_) in enumerate(spatial_shapes):
# =========================================================================
# (1)torch.linspace(0.5, H_ - 0.5, H_): 可以看成按0.5像素单位把H方向切割成若干份
# (2)torch.linspace(0.5, W_ - 0.5, W_): 可以看成按0.5像素单位把W方向切割成若干份
# 例如设H_, W_ = 12, 16, 则(1)和(2)分别为:
# tensor([ 0.5000, 1.5000, 2.5000, 3.5000, 4.5000, 5.5000, 6.5000, # 7.5000, 8.5000, 9.5000, 10.5000, 11.5000])
# tensor([ 0.5000, 1.5000, 2.5000, 3.5000, 4.5000, 5.5000, 6.5000, # 7.5000, 8.5000, 9.5000, 10.5000, 11.5000, 12.5000, 13.5000, # 14.5000, 15.5000])
#
# 你可以想成把一张特征图横向划几条线,纵向画几条线。对于一个像素格子,我们用其质心坐标
# 表示它,这相当于是这些线的交界点
#
# 这里ref_y表示每个ref点的x坐标(H方向坐标),ref_x表示每个ref点的y坐标(W方向坐标)
# (3) ref_y: 尺寸为(H_, W_), 形式如:
# tensor([[ 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, # 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, # 0.5000, 0.5000],
# [ 1.5000, 1.5000, 1.5000, 1.5000, 1.5000, 1.5000, 1.5000, # 1.5000, 1.5000, 1.5000, 1.5000, 1.5000, 1.5000, 1.5000, ¥ 1.5000, 1.5000],
# ...
# [11.5000, 11.5000, 11.5000, 11.5000, 11.5000, 11.5000, 11.5000, # 11.5000,
# 11.5000, 11.5000, 11.5000, 11.5000, 11.5000, 11.5000, 11.5000, # 11.5000]])
#
# (4) ref_x:尺寸为(H_, W_),形式如:
# tensor([[ 0.5000, 1.5000, 2.5000, 3.5000, 4.5000, 5.5000, 6.5000, # 7.5000, 8.5000, 9.5000, 10.5000, 11.5000, 12.5000, 13.5000, # 14.5000, 15.5000],
# [ 0.5000, 1.5000, 2.5000, 3.5000, 4.5000, 5.5000, 6.5000, # 7.5000, 8.5000, 9.5000, 10.5000, 11.5000, 12.5000, 13.5000, # 14.5000, 15.5000],
# ...(重复下去)
# =========================================================================
ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))
# =========================================================================
# 相当于每个像素格子都用其中心点的坐标来表示它
# ref_y.reshape(-1)[None]: 把(H_, W_)展平成(1, H_*W_)。
# 例如H_=12, W_=16, 则展平成(1, 192)
# tensor([[ 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, # 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, # 0.5000, 0.5000,
# 1.5000, 1.5000, 1.5000, 1.5000, 1.5000, 1.5000, 1.5000, # 1.5000, 1.5000, 1.5000, 1.5000, 1.5000, 1.5000, 1.5000, # 1.5000, 1.5000,
# ...
# 11.5000, 11.5000, 11.5000, 11.5000, 11.5000, 11.5000, 11.5000, # 11.5000, 11.5000, 11.5000, 11.5000, 11.5000, 11.5000, 11.5000, # 11.5000, 11.5000]])
#
# ref_x.reshape(-1)[None]:把(H_, W_)展平成(1, H_*W_)。例子同上
# tensor([[ 0.5000, 1.5000, 2.5000, 3.5000, 4.5000, 5.5000, 6.5000, # 7.5000, 8.5000, 9.5000, 10.5000, 11.5000, 12.5000, 13.5000, # 14.5000, 15.5000, 0.5000, 1.5000, 2.5000, 3.5000, 4.5000, # 5.5000, 6.5000, 7.5000, 8.5000, 9.5000, 10.5000, 11.5000, # 12.5000, 13.5000, 14.5000, 15.5000,
# ...
#
# valid_ratios[:, None, lvl, 1]:
# 取出batch中所有数据在当前lvl这层特征图上,H方向的有效比例(有效=非padding)
# 尺寸为(B, 1),例如:
# tensor([[0.6667],
# [0.6667],
# [0.9167],
# [1.0000]])
# 乘上H_后表示实际有效的像素级长度
# valid_ratios[:, None, lvl, 0]:也是同理
#
# ref_y: 尺寸为(B, H_ * W_)。
# 表示对于batch中的每条数据,它在该lvl层特征图上一共有H_*W_个参考点,ref_y
# 表示这些参考点最终在H方向上的像素坐标。【但这里像素坐标做了类似归一划的处理。
# ref_y = 原始H方向的绝对像素坐标/H方向上有效即非padding部分的绝对像素长度
# 因此该值如果 > 1则说明该参考点在padding部分】
# ref_x:尺寸为(B, H_ * W_)。同上
# =========================================================================
ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
# =========================================================================
# ref:尺寸为(B, H_*W_, 2),表示对于batch中的每条数据,它在该lvl层特征图上所有H_*W_个参考点的x,y坐标
# 如上所说,该坐标已经处理成相对于有效像素长度的形式
# 【特别注意!这里W和H换了位置!!!!!!】
# =========================================================================
ref = torch.stack((ref_x, ref_y), -1)
reference_points_list.append(ref)
# =========================================================================
# 尺寸为:(B, sum(H_*W_), 2)。表示对于一个batch中的每一条数据,
# 它在所有特征图上的参考点(数量为sum(H_*w))的像素坐标x,y。
# 这里x,y都做过处理,表示该参考点相对于非padding部分的H,W的比值
# 例如,如果x和y>1, 则说明该参考点在padding部分
# =========================================================================
reference_points = torch.cat(reference_points_list, 1)
# =========================================================================
# 尺寸为:(B, sum(H_*W_), level_num, 2)。表示对于batch中的每一条数据,
# 它的每一个特征层上的归一化坐标,在其他特征层(也包括自己)上的所有归一化坐标
# 假设对于某条数据:
# 特征图1的高度为H1,有效高度为HE1,其上某个ref点x坐标为h1。则该ref点归一化坐标可以表示成h1/H1
# 特征图2的高度为H2,有效高度为HE2,那么特征图1中的ref点在特征图2中,对应的归一化坐标应该是多少?
# 【常规思路】:正常情况下,你可能觉得,我只要对每一张特征图上的像素点坐标都做归一化,然后对任意两张特征图,
# 我取出像素点坐标一致的ref点,它不就能表示两张特征图的相同位置吗?
# 【问题】:每张特征图padding的比例不一样,因此不能这么做。举例(参见草稿纸上的图)。特征图1上绝对像素
# 位置3.5的点,在特征图2上的位置是2.1,在有效图部分之外
# 【正确做法】:把特征图上的坐标表示成相对于有效部分的比例
# 【解】:我们希望 h1/HE1 = h2/HE2,在此基础上我们再来求h2/H2
# 则我们有:h2 = (h1/HE1) * HE2,进一步有
# (h2/H2) = (h1/HE1) * (HE2/H2),而(h1/HE1)就是reference_points[:, :, None],
# (HE2/H2)就是valid_ratios[:, None]
# 所以,这里是先将不同特征图之间的映射转为“绝对坐标/有效长度”的表达,然后再转成该绝对坐标在整体长度上的比例
# =========================================================================
reference_points = reference_points[:, :, None] * valid_ratios[:, None]
return reference_points
参考
1、https://arxiv.org/abs/2010.04159
2、https://github.com/fundamentalvision/Deformable-DETR
3、https://colab.research.google.com/github/facebookresearch/detr/blob/colab/notebooks/detr_attention.ipynb#scrollTo=8yls9cpVYTEg

公众号后台回复“数据集”获取100+深度学习各方向资源整理
极市干货

点击阅读原文进入CV社区
收获更多技术干货

