
极市导读
本文作者提供了关于Vision Transformer在Out-of-Distribution数据上泛化能力的全面研究。 >>加入极市CV技术交流群,走在计算机视觉的最前沿
最近,Vision Transformer(ViT)在各类计算机视觉任务中都取得了令人瞩目的成果。然而,人们对它们在面对不同 Distribution Shifts 下的泛化能力知之甚少。因此,北京航空航天大学、NTU S-Lab等机构的工作提供了关于 Vision Transformer 在out-of-distribution数据上泛化能力的全面研究。
论文链接:https://arxiv.org/abs/2106.07617
GitHub地址:https://github.com/Phoenix1153/ViT_OOD_generalization
一、Summary
为了支持系统的研究,我们首先对Distribution Shift进行分类,将它们划分为五个概念组:background shift,corruption shift,texture shift,destruction shift和style shift。
接下来,我们对不同distribution shift下的Vision Transformer进行了广泛的评估,并将其泛化能力与卷积神经网络(CNN)模型进行了比较,得到了几个重要的观察结果:
-
在大多数Distribution Shift下,Vision Transformer的泛化能力优于CNN。在相同或更少的参数量下,在大多数类型的distribution shift下,Vision Transformer的 top-1 正确率比对应的CNN领先5%至10%。
-
随着Vision Transformer模型结构逐渐变大,它面对正常分布(in-distribution)数据和out-of-distribution数据的泛化性能差距会逐渐缩小。
-
然后,为了进一步提高Vision Transformer的泛化能力,我们分别设计了结合对抗学习、信息论和自监督学习的三种泛化能力提升的Vision Transformer。通过研究这三种类型的泛化增强Vision Transformer,我们观察到了Vision Transformer模型针对梯度的敏感性,并设计一个更平滑的学习策略,以实现稳定的训练过程。通过修改的训练方案,我们实现了相较于原始Vision Transformer在out-of-distribution数据下的泛化性能4%左右的提升。
通过将这三种泛化增强的 Vision Transformer 与它们对应的 CNN 模型进行综合比较,得到以下结论:
-
对于泛化增强的Vision Transformer,模型结构越庞大,其对于out-of-distribution数据的泛化能力得到的增益更多。 -
与相应的CNN模型相比,泛化增强的 Vision Transformer 对超参数更敏感。希望我们的综合研究能够为设计更一般化的学习架构提供启发。
二、Distribution Shift的分类方法
为了对Out-of-Distribution数据上的泛化能力进行广泛的研究,我们基于原始图像被修改的语义概念,对Distribution Shift进行了分类。在图像分类任务中,图像通常由前景的物体和背景组成。以往的工作通常假设在图像中出现的语义概念具有层次结构 [1] 。这些语义概念可以从低级到高级可列举为:像素级元素、物体纹理和形状、物体部件、物体本身。因此,我们将Distribution Shift分为五种情况:Background Shift,Corruption Shift,Texture Shift,Destruction Shift和Style Shift,如图1所示。
具体地,我们从每种 Distribution Shift 种选取一系列具有代表性的样本用于计算 PAD 值,并按照升序排列,在图中从内环向外环展示。细节部分请查看原文。
-
Background Shift 在图像分类任务中,图像背景通常被视为标签分配过程中的辅助线索。然而,以往的工作已经证明,背景可能在模型结果预测中占主导地位,而这是我们希望避免发生的。因此我们关注模型对背景变化的不变性,从而定义Background Shift。Background Shift探究中使用了 ImageNet-9 [2] 数据集。
-
Corruption Shift Corruption的概念是在 [3] 中提出的,它代表那些在图像中自然出现的局部杂质。这些Corruption要么来自拍摄阶段的环境影响,要么来自图像处理阶段。我们将这些情况定义为Corruption Shift,即使它只影响到物体像素级的元素,却仍然会导致模型的性能明显下降。ImageNet-C [3] 被用于检验Corruption Shift下的泛化能力。
-
Texture Shift 一般来说,图像的纹理给了我们颜色或密度的空间排列信息,这对于模型获得正确的预测是至关重要的。因此,对物体纹理的替换会影响模型的预测。我们将这些变化定义为texture shift。我们使用Cue Conflict Stimuli和Stylized-ImageNet [4] 来研究texture shift下的泛化能力。
-
Destruction Shift 对应于将整个物体分解成碎片。例如,random patch-shuffle 过程将图像分割成若干个正方形的patch,并随机打乱这些 patch 的位置。这个过程可以破坏物体的全局信息,并且随着分割数目的增加,破坏程度也会增加。另外,我们将每个patch进一步划分为两个直角三角形,并分别随机排序两种类型的三角形。我们将此过程命名为triangular patch-shuffle。
-
Style Shift 通常情况下,风格是一个复杂的概念,由描述艺术品的特征所决定,如形式、颜色、构成等。风格的变化往往体现在多个概念层次上,包括纹理、形状、对象部分等。例如,对比一个简笔画和相应的照片,我们可以观察到纹理和颜色的差异,以及一些不重要的物体部分在简笔画被忽略的情况。ImageNet-R [5] 和 DomainNet [6] 用于style shift的探究。
三、Vision Transformer泛化性系统研究及与CNN的比较
评测指标
假设图像分类模型包含特征提取器 和分类器 , 训练集为 。分 别 引 入独 立同 分布 的 测试 集 和 out-of-distribution 数 据 集 。则本文采用的评测指标包含:
-
OOD Accuracy. 即模型在 out-of-distribution 数据集上的正确率:
-
IID/OOD Generalization Gap. 本文中,我们同样关注模型在out-of-distribution数据上相对于独立同分布数据的表现差距,因此我们使用了IID/OOD generalization gap作为评测指标:
结果分析
第一列图和第二列图分别表示不同模型在ImageNet-9和ImageNet-C数据集上的OOD Accuracy以及IID/OOD Generalization Gap结果。从左二图中,我们可以总结出以下结论:
-
Vision Transformer相对会更少地将背景用于类别预测,而且这种性质并不是由训练过程中复杂的Augmentation带来的。
-
Vision Transformer越大,越会将更多注意力放在前景上,从而学到了更加和背景无关的表征。
从右二图中我们可以得到结论:
-
更大的Vision Transformer在此情况下泛化性能更好。
-
Vision Transformer训练过程中是用的patch尺寸并不会影响模型在IID和OOD数据上的差距,而是影响模型在IID数据上的泛化能力。
左二图从分别表示不同模型在Stylized-ImageNet和ImageNet-R数据集上的OOD Accuracy以及IID/OOD Generalization Gap结果,最右表示,不同模型在Cue Conflict Stimuli上的表现。从Stylized-ImageNet和Cue Conflict Stimuli的结果可以得到结论:
-
Vision Transformer面对texture shift表现更好。
-
Vision Transformer越大,对全局语义信息的关注越多,受局部纹理变化的影响越小。
-
使用更大patch尺寸进行训练的 Vision Transformer对局部的纹理信息的依赖更小,更关注全局信息。
从ImageNet-R的结果中可以关注到大部分 Vision Transformer在OOD Accuracy上表现的比BiT好,但是它们的IID/OOD generalization gap差异很小。
上排2图及下排2图分别表示不同模型在不同的 Shuffle Patch 尺寸下的 patch-Shuffled 和 Triangular Patch-shuffled ImageNet过程的OOD Accuracy以及IID/OOD Generalization Gap结果。实验中使用的图像尺寸为 。可以看出,在用于随机排序的patch尺寸减小到小于训练patch尺寸时,所有Vision Transformer 的泛化性能会从很好的表现快速崩溃。从此我们可以总结出 Vision Transformer 从 Position Embedding 中获取的空间信息,相对于 CNN 来说,依赖较少。因此其直到每个处理的 patch 内部的像素相对关系收到了破坏后,性能会快速崩溃。
分别列举了DeiT-B/16,DeiT-S/16,BiT和 的结果。从结果中我们可以总结到:
-
在类似的模型参数规模下,DeiT-S/16因为其在IID数据上的优秀表现,从而在OOD上的表现也明显优于BiT模型。
-
但当观察IID/OOD Generalization Gap结果时,我们可以看出在一部分情况下(clipart和painting),DeiT-S/16并没有比BiT表现更好。
从图3-6中的五类Distribution Shift下进行的结果分析可以得出以下结论:
-
在大多数Distribution Shift下,Vision Transformer的泛化能力优于CNN。在相同或更少的参数量下,在大多数类型的Distribution Shift下,Vision Transformer的 top-1 正确率比对应的CNN领先5%以上。
-
随着Vision Transformer模型结构逐渐变大,它面对正常分布(in-distribution)数据和Out-of-distribution数据的泛化性能差距会逐渐缩小。
四、泛化增强的Vision Transformer
所有网络均包含一个Vision Transformer 作为特征提取器以及分类器 。在该部分的设定 下,模型的输入数据包含有标注的源域数据和未标注的目标域数据。
-
左上角:T-ADV通过引入域判别器 进行域对抗训练,从而促使网络学到域不变的数据表征。
-
右上角:T-MME利用目标域数据的条件熵上的极大极小过程来减小分布差距,同时学习到具有判别性的特征。该模型使用了基于余弦相似度的分类器结构 ,用于生成各个prototype。
-
下图:T-SSL是一个基于原型的自监督端到端学习框架。该框架中使用了两个memory bank 和 来计算聚类中心。该框架同样使用了基于余弦相似度的分类器结构 。
基于对抗学习的T-ADV
为了学习到域不变的表征,我们引入了一个域判别器[7],通过对抗训练的方式引导特征提取模块达成目标。如图6左上角所示,网络包含了一个Vision Transformer 作为共享的特征提取器,后面两个分支分别为分类器 和域判别器 。特征提取器旨在同时最小化两个域上的domainconfusion loss 以及有标注的源域上的label prediction loss ; 与此同时,域判别器旨在最大化 domain confusion loss 。整个框架的优化目标为:
其中 和 分别表示物体类别标签和二分类的域类别标签。 和 分别代表Softmax和交叉嫡。 是一个迭代变换的平衡系数,迭代过程遵循 [7]中的方法。
为了简化训练过程,我们同样使用了Gradient Reversal Layer (GRL)来实现不同模块间的梯度翻转。
基于信息论的T-MME
为了减小两个域在表征上的分布偏差,同时学习到具有判别性的特征,我们对目标域的条件商使用了[8]中的极大极小交替优化过程。该方法的pipeline如图6右上角所示,该模型使用了基于余 和温度系数 组成, 其输入为 正则化的特征 , 输出 。其核心思想是将prototype与相邻的未标记目标域样本之间的距离最小化,从而提取出具有判别性的目标特征。为了克服源域标注数据对prototype的支配作用,我们使用最大化目标域输出的嫡 的方法来促使原型向目标域数据靠近。与此同时,特征提取器以最小化未标记样本的商为目标,使其更好地聚集在prototype周围。因此,在权重向量和特征提取器之间形成一个极大极小优化过程。另外,目标函数同时包含了标注数据的label prediction loss 。因此整体的优化目标为:
其中 表示嫡, 是一个平衡系数。
基于自监督学习的T-SSL
我们谈久了端到端的prototypical自监督学习框架[9]在Vision Transformer的应用效果,如图6下图所示,该框架同样使用了 基于余弦相似度的分类器结构 。其首先将数据中的语义结构嵌入到一个空间中。接下来,分别在源域和目标域上使用ProtoNCE [10]。具体地,该框架中使用了两个memory bank 和 来存储两个域中每个样本的特征向量并在每个batch后进行动量更新,接着使用 means计算聚类中心 和 , 用于计算当前batch的源域 向量 和prototype 的相似性分布向量 , 其中 。则 in-domain prototypical self-supervision loss 为:
其中 和 分别代表两域数据族的编号, 代表集合的势。
此外,由于我们期望网络能产生高置信度和多样化的预测,因此我们设计了一个目标,即最大限度地增加输入图像和网络预测之间的互信息。该目标分为两项,网络预测数学希望的熵最大化和网络输出的熵最小化。因此该部分的目标为:
最后一项目标为标注数据的label prediction loss :
总体的优化目标为:
其中 和 代表平衡系数。
五、泛化增强的Vision Transformer的效果检测
我们对比了三种提升泛化能力的Vision Transformer和对应的CNN模型。从结果中可以总结到:
-
使用泛化增强方法后,Vision Transformer面对out-of-distribution数据的表现提升了4%。
-
三种泛化增强方法对Vision Transformer带来的增益基本相同。
-
越大的Vision Transformer从泛化增强方法中获得的增益越多。
从左至右分别表示T-ADV,T-MME和T-SSL在源域和目标域的训练曲线。从结果中我们可以总结出CNN上使用的传统的训练策略(绿线)不适用于Vision Transformer,需要更平滑的训练策略(红线)来促使模型更好地对齐两域的特征。具体地,相较于传统的训练策略,取消多样化的Auto Augmentation数据预处理过程以及对各种方法中用于提升泛化能力的loss项使用Adaptive渐增的系数更新方法,均有助于提升训练过程中的稳定性以及收敛的结果。
表格1中总结了泛化增强模型的实验结果。从结果中可以总结到:
-
使用泛化增强方法后,Vision Transformer面对Out-of-Distribution数据的表现提升了4%。
-
三种泛化增强方法对 Vision Transformer 带来的增益基本相同。
-
越大的 Vision Transformer 从泛化增强方法中获得的增益越多。
图7为对比实验,展示了不同训练策略对泛化增强的Vision Transformer的影响。绿线代表使用CNN上的传统训练策略,其他两个代表更平滑的策略。从这些策略的比较中可以看出,目前普遍使用的 Auto Augmentatipn 会导致T-ADV的性能下降,而对T-MME和T-SSL的影响很小。同时,平滑学习策略对于Vision Transformer 收敛具有重要意义,特别是在对抗训练模式下。对于T-MME和T-SSL,loss的平滑性也显著提高了性能。基于这些观察,我们得出结论,Vision Transformer 比相应的CNN模型对梯度更敏感,从而证明我们设计的改进的平滑训练策略的必要性。
六、总结
本文中,我们提供了关于Vision Transformer在Out-of-Distribution数据上泛化能力的全面研究,并作出了以下贡献:
-
我们根据图像中改变的语义概念,对Distribution Shift进行了分类。
-
我们提供了对Vision Transformer在五种类别的Distribution Shift下的Out-of-distribution泛化能力进行了详尽研究。
-
我们分别通过设计基于对抗训练,信息论以及自监督学习的泛化增强Vision Transformer来进一步提升模型泛化能力,并使用了平滑的训练策略以适应Vision Transformer。
我们的工作只是一个早期的尝试,因此对于开发更强大的泛化增强Vision Transformer还有很大的空间。
参考文献
Illustration by Aleksey Chizhikov from Icons8
如果觉得有用,就请分享到朋友圈吧!
公众号后台回复“CVPR21检测”获取CVPR2021论文下载~

# CV技术社群邀请函 #
备注:姓名-学校/公司-研究方向-城市(如:小极-北大-目标检测-深圳)
即可申请加入极市目标检测/图像分割/工业检测/人脸/医学影像/3D/SLAM/自动驾驶/超分辨率/姿态估计/ReID/GAN/图像增强/OCR/视频理解等技术交流群
每月大咖直播分享、真实项目需求对接、求职内推、算法竞赛、干货资讯汇总、与 10000+来自港科大、北大、清华、中科院、CMU、腾讯、百度等名校名企视觉开发者互动交流~

