大数跨境
0
0

NeurIPS 22|即插即用!数据集蒸馏新方法HaBa:显著提升下游模型性能

NeurIPS 22|即插即用!数据集蒸馏新方法HaBa:显著提升下游模型性能 极市平台
2023-01-12
0
↑ 点击蓝字 关注极市平台
作者丨GlobalTrack
编辑丨极市平台

极市导读

 

本文方法将合成数据集分解为两个部分:数据幻觉器网络和基础数据。数据幻觉器网络将基础数据作为输入,输出幻觉图像(合成图像)。该方法得到的合成数据集在跨架构任务中比基准方法取得了精度10%的提升。 >>加入极市CV技术交流群,走在计算机视觉的最前沿

论文链接: https://openreview.net/pdf?id=luGXvawYWJ

代码链接: https://github.com/Huage001/DatasetFactorization

简介

深度学习取得了巨大成功,训练一般需要大量的数据。存储、传输和数据集预处理成为大数据集使用的阻碍。另外发布原始数据可能会有隐私版权等问题。

数据集蒸馏(Dataset Distillation)是一种解决方案,通过蒸馏一个数据集形成一个只包含少量样本的合成数据集,同时训练成本显著降低。数据集蒸馏可以用于持续学习、神经网络架构搜索等领域。

最早提出的数据集蒸馏算法核心思想即优化合成数据集,在下游任务中最小化损失函数。DSA( Dataset condensation with differentiable siamese augmentation)、GM( Dataset condensation with gradient matching)、CS(Dataset condensation with contrastive signals)等方法提出匹配真实数据集和合成数据集的梯度信息的算法。MTT(Dataset distillation by matching training trajectories)指出由于跨多个步骤的误差累计,单次迭代的训练误差可能导致较差的性能,提出在真实数据集上匹配模型的长期动态训练过程。除了匹配梯度信息的方法,DM(Dataset condensation with distribution matching)提出了匹配数据集分布,具体方法是添加最大平均差异约束( Maximum Mean Discrepancy,MMD)。

本文方法将合成数据集分解为两个部分:数据幻觉器网络(Data Hallucination Network)和基础数据(Bases)。数据幻觉器网络将基础数据作为输入,输出幻觉图像(合成图像)。在数据幻觉器网络训练过程中,本文考虑添加特殊设计的对比学习损失和一致性损失。本文方法得到的合成数据集在跨架构任务中比基准方法取得了精度10%的提升。

方法

传统的数据集蒸馏方法将合成样本独立处理,忽略了不同样本间的内部关系,可能导致较差的数据效率。本文方法提出将数据集蒸馏定义为包含 个幻觉器和 个基础样本的幻觉器-基(hallucinator-basis)的分解问题:

训练过程时,训练数据通过传入第 个基础数据在线生成。合成数据可以表示为:

基与幻觉器

先前数据集蒸馏方法中,为了在下游模型中输入和输出的形状保持一直,合成数据的形状需要与真实数据相同。由于幻觉器网络可以使用空间和通道变换,本文方法没有形状相同限制。

给了基数据 ,一个幻觉器网络,目标创建一个输出 。该任务可以视作为一个条件图像生成问题。借鉴于风格迁移任务,本文的幻觉器网络设计为encoder-transform-decoder架构。编码器由若干卷积层组成,将输入非线性映射。之后经过尺度 和位移 的仿射变化。 是网络参数。解码器是和编码器对称的CNN网络将特征映射到图像空间。

对抗性对比约束

本文的幻觉器网络训练过程是一个最小-最大博弈(min-max game)过程。最大化过程即最大化不同幻觉器间的差异。输入 在幻觉器最后一层的输出定义为 。损失函数类似于对比学习,可以描述为:

对于图像分类任务,

另一个损失函数关注于减少幻觉器网络输出 间的差异,核心目标是增加合成数据集的数据多样性。损失函数可以描述为:

分解训练方法

与先前的数据集蒸馏方法训练范式类似,合成数据集按照迭代算法更新。每一个迭代周期,随机选取幻觉器和基,形成若干幻觉器-基组合。训练的损失函数包含知识蒸馏损失与一致性损失:

本文的数据集蒸馏损失函数采用MTT方法。核心思想是使用训练周期为 的模型权重,使用合成数据集 训练 次,使用真实数据集 训练 次,通过损失函数使合成数据集更新的参数与真实数据集更新的参数保持一致:

实验

与SOTA方法的比较结果。比较的方法包括核心集算法(Coreset),数据集蒸馏方法(元学习方法DD、LD,训练匹配方法DC、DSA、DSA,分布匹配方法DM、CAFE)和本文方法Factorization。超参数,每一类合成样本数(IPC)[1,10,50],本文的每一类基数量(BPC)[1,9,49]。

下图给出了实验结果。可以看出本文方法取得了最高的精度,在合成数据集样本数小于1%时性能差异最为显著。

与不同合成数据集生成算法和不同卷积神经网络模型组合的比较实验。在AlexNet网络的实验中,本文的方法与MTT相比最高取得了17.57%的性能提升。

不同类别是否共享幻觉器的Ablation实验。在相同的BPC条件下,较少的合成样本数情况下不共享幻觉器的方法(w/o share)可以获得更好的性能。较多的BPC情况下,不共享幻觉器方法不能获得更好的性能。主要原因:1)共享幻觉器方法可以获得数据集的全局信息。2)不共享幻觉器的方法给优化过程较大的负担

本文方法基和幻觉器生成图像的可视化如下:

公众号后台回复“CNN综述”获取67页综述深度卷积神经网络架构

极市干货

技术干货损失函数技术总结及Pytorch使用示例深度学习有哪些trick?目标检测正负样本区分策略和平衡策略总结

实操教程GPU多卡并行训练总结(以pytorch为例)CUDA WarpReduce 学习笔记卷积神经网络压缩方法总结

极市原创作者激励计划 #


极市平台深耕CV开发者领域近5年,拥有一大批优质CV开发者受众,覆盖微信、知乎、B站、微博等多个渠道。通过极市平台,您的文章的观点和看法能分享至更多CV开发者,既能体现文章的价值,又能让文章在视觉圈内得到更大程度上的推广,并且极市还将给予优质的作者可观的稿酬!

我们欢迎领域内的各位来进行投稿或者是宣传自己/团队的工作,让知识成为最为流通的干货!

对于优质内容开发者,极市可推荐至国内优秀出版社合作出书,同时为开发者引荐行业大牛,组织个人分享交流会,推荐名企就业机会等。


投稿须知:
1.作者保证投稿作品为自己的原创作品。
2.极市平台尊重原作者署名权,并支付相应稿费。文章发布后,版权仍属于原作者。
3.原作者可以将文章发在其他平台的个人账号,但需要在文章顶部标明首发于极市平台

投稿方式:
添加小编微信Fengcall(微信号:fengcall19),备注:姓名-投稿

点击阅读原文进入CV社区

收获更多技术干货

【声明】内容源于网络
0
0
极市平台
为计算机视觉开发者提供全流程算法开发训练平台,以及大咖技术分享、社区交流、竞赛实践等丰富的内容与服务。
内容 8155
粉丝 0
极市平台 为计算机视觉开发者提供全流程算法开发训练平台,以及大咖技术分享、社区交流、竞赛实践等丰富的内容与服务。
总阅读5.7k
粉丝0
内容8.2k