大数跨境
0
0

AAAI 2023 | 动态温度超参蒸馏新方法

AAAI 2023 | 动态温度超参蒸馏新方法 极市平台
2023-01-05
1
↑ 点击蓝字 关注极市平台
作者丨Zheng Li@知乎(已授权)
来源丨https://zhuanlan.zhihu.com/p/595735843
编辑丨极市平台

极市导读

 

在蒸馏任务里,能不能让网络自己学习一个适合的动态温度超参进行蒸馏,并且参考课程学习,形成一个蒸馏难度由易到难的情况?基于此,本文提出了简单且高效的动态温度超参蒸馏新方法CTKD。 >>加入极市CV技术交流群,走在计算机视觉的最前沿

论文题目:<Curriculum Temperature for Knowledge Distillation> (AAAI 2023)

论文地址:https://arxiv.org/abs/2211.16231

开源代码:https://github.com/zhengli97/CTKD (欢迎star)

一句话概括:

相对于静态温度超参蒸馏,本文提出了简单且高效的动态温度超参蒸馏新方法。

背景问题:

目前已有的蒸馏方法中,都会采用带有温度超参的KL Divergence Loss进行计算,从而在教师模型和学生模型之间进行蒸馏,公式如下:

其中, 温度超参 的大小控制了两个预测结果 的平滑程度, 决定了两个概率分布间的距离, 越大( , 就会使得概率分布越平滑(soft), 越小 , 越接近0, 会使得概率分布越尖锐(sharp)。τ的大小影响着蒸馏中学生模型学习的难度,不同的τ会产生不同的蒸馏结果。

而现有工作普遍的方式都是采用固定的温度超参,一般会设定成4。

那么这就带来了两个问题:

1. 不同的教师学生模型在KD过程中最优超参不一定是4。如果要找到这个最佳超参,需要进行暴力搜索,会带来大量的计算,整个过程非常低效

2. 一直保持静态固定的温度超参对学生模型来说不是最优的。基于课程学习的思想,人类在学习过程中都是由简单到困难的学习知识。那么在蒸馏的过程中,我们也会希望模型一开始蒸馏是让学生容易学习的,然后难度再增加。难度是一直动态变化的。

于是一个自然而然的想法就冒了出来:

在蒸馏任务里,能不能让网络自己学习一个适合的动态温度超参进行蒸馏,并且参考课程学习,形成一个蒸馏难度由易到难的情况?

于是我们就提出了CTKD来实现这个想法。

方法:

既然温度超参τ可以在蒸馏里决定两个分布之间的KL Divergence,进而影响模型的学习,那我们就可以通过让网络自动学习一个合适的τ来达到以上的目的。

于是以上具体问题就直接可以转化成以下的核心思想

在蒸馏过程里,学生网络被训练去最小化KL loss的情况下,τ作为一个可学习的参数,要被训练去最大化KL loss,从而发挥对抗(Adversarial)的作用,增加训练的难度。随着训练的进行,对抗的作用要不断增加,达到课程学习的效果。

以上的实现可以直接利用一个非常简单的操作:利用梯度反向层GRL (Gradient Reversal Layer)来去反向可学习超参τ的梯度,就可以非常直接达到对抗的效果,同时随着训练的进行,不断增加反向梯度的权重λ,进而增加学习的难度。

CTKD的论文的结构图如下:

Fig.1 CTKD网络结构图

CTKD方法可以简单分为左右两个部分:

  1. 对抗温度超参τ的学习部分。

这里只包含两个小模块,一个是梯度反向层GRL,用于反向经过温度超参τ的梯度,另一个是可学习超参温度τ。

其中对于温度超参τ,有两种实现方式,第一种是全局方案 (Global Temperature),只会产生一个τ,代码实现非常简单,就一句话:

self.global_T = nn.Parameter(torch.ones(1), requires_grad=True)

第二种是实例级别方案(Instance-wise Temperature),即对每个单独的样本都产生一个τ,也就是对于一个batch中128个sample,那么就生成对应128个τ。代码实现也很简单,就是两层conv组成的MLP。

两种方案的对比图如Fig.2所示。

Fig.2 两种不同的可学习温度超参实现。

2. 难度逐渐增加的课程学习部分。

随着训练的进行,不断增加GRL的权重λ,达到增加学习难度的效果。

在论文的实现里,我们直接采用Cos的方式,让反向权重λ从0增加到1。

以上就是CTKD的全部实现,非常的简单有效。

总结一下方法:CTKD总共包含两个模块,梯度反向层GRL和温度预测模块,

CTKD方法可以作为即插即用的插件应用在现有的SOTA的蒸馏方法中,取得广泛的提升。

实验结果

三个数据集:CIFAR-100,ImageNet和MS-COCO。

CIFAR-100上,CTKD的实验结果:

作为一个即插即用的插件,应用在已有的SOTA方法上:

在ImageNet上的实验:

在MS-COCO的detection实验上:

温度超参的整体学习过程可视化:

由以上图可以看到,CTKD整体的动态学习τ的过程。

将CTKD应用在多种现有的蒸馏方案上,可以取得广泛的提升效果。

欢迎大家试用~

公众号后台回复“2022盘点”即可获取2022年32篇AI热点论文PDF资源包

△点击卡片关注极市平台,获取最新CV干货

极市干货

技术干货损失函数技术总结及Pytorch使用示例深度学习有哪些trick?目标检测正负样本区分策略和平衡策略总结
实操教程GPU多卡并行训练总结(以pytorch为例)CUDA WarpReduce 学习笔记卷积神经网络压缩方法总结

极市原创作者激励计划 #


极市平台深耕CV开发者领域近5年,拥有一大批优质CV开发者受众,覆盖微信、知乎、B站、微博等多个渠道。通过极市平台,您的文章的观点和看法能分享至更多CV开发者,既能体现文章的价值,又能让文章在视觉圈内得到更大程度上的推广,并且极市还将给予优质的作者可观的稿酬!

我们欢迎领域内的各位来进行投稿或者是宣传自己/团队的工作,让知识成为最为流通的干货!

对于优质内容开发者,极市可推荐至国内优秀出版社合作出书,同时为开发者引荐行业大牛,组织个人分享交流会,推荐名企就业机会等。


投稿须知:
1.作者保证投稿作品为自己的原创作品。
2.极市平台尊重原作者署名权,并支付相应稿费。文章发布后,版权仍属于原作者。
3.原作者可以将文章发在其他平台的个人账号,但需要在文章顶部标明首发于极市平台

投稿方式:
添加小编微信Fengcall(微信号:fengcall19),备注:姓名-投稿

点击阅读原文进入CV社区

收获更多技术干货

【声明】内容源于网络
0
0
极市平台
为计算机视觉开发者提供全流程算法开发训练平台,以及大咖技术分享、社区交流、竞赛实践等丰富的内容与服务。
内容 8155
粉丝 0
极市平台 为计算机视觉开发者提供全流程算法开发训练平台,以及大咖技术分享、社区交流、竞赛实践等丰富的内容与服务。
总阅读3.2k
粉丝0
内容8.2k