大数跨境
0
0

ICCV 2023|从蒸馏到自蒸馏:通用归一化损失与定制软标签

ICCV 2023|从蒸馏到自蒸馏:通用归一化损失与定制软标签 极市平台
2023-07-25
1
导读:代码已开源~
↑ 点击蓝字 关注极市平台
作者丨美索不达米亚平原@知乎
来源丨https://zhuanlan.zhihu.com/p/644157944
编辑丨极市平台

极市导读

 

文章包括了对logit蒸馏损失计算方法的改进,并基于改进后的公式提出了定制的软标签,用于实现自蒸馏。代码已开源。 >>加入极市CV技术交流群,走在计算机视觉的最前沿

本文介绍我们ICCV 2023关于知识蒸馏的工作: From Knowledge Distillation to Self-Knowledge Distillation: A Unified Approach with Normalized Loss and Customized Soft Labels, 文章包括了对logit蒸馏损失计算方法的改进,并基于改进后的公式提出了定制的软标签,用于实现自蒸馏。现在代码已经开源,欢迎大家使用~(转载请注明出处)

文章链接:https://arxiv.org/abs/2303.13005

代码链接:https://github.com/yzd-v/cls_KD

一、简介

原生蒸馏使用教师的logits作为软标签,与学生的输出计算蒸馏损失。自蒸馏则试图在缺乏教师模型的条件下,通过设计的额外分支或者特殊的分布来获得软标签,再与学生的输出计算蒸馏损失。二者的差异在于获得软标签的方式不同。

这篇文章旨在,1)改进计算蒸馏损失的方法,使得学生能更好地使用软标签。2)提出一种通用的高效简单的方法获得更好的软标签,用于提升自蒸馏的性能和通用性。针对这两个目标,我们分别提出了Normalized KD(NKD)和Universal Self-Knowledge Distillation (USKD)。

二、方法与细节

1)NKD

表示label的值, 分类任务采用交叉熵作为模型训练的原损失:

原生的蒸馏损失表示为:

蒸馏损失的第一项和模型原损失一致, 均是关于目标类别target。而蒸馏损失的第二项则是交叉樀 的形式, 交叉樀损失的优化目标是使 接近, 观察蒸馏损失的第二项可知, 。在训练中, 学生输出的目标类别概率 是在不断变化的, 无法恰好与 相等, 这使得两个non-target logits的和不等, 阻碍了 变得接近。因此我们针对两个non-target logits进行归一化, 强制使他们相等, 提出了 Normalized KD, 用于更好地使用软标签:

2)USKD

根据NKD的公式, 我们从target和non-target两个角度来人工设计软标签, 以实现自蒸馏。首先针对target部分, 教师输出的 在训练中是稳定的, 并且反映了图片的分类难度。而在自蒸馏中, 我们能使用的只有学生输出的 , 它在训练前期的值很小, 并且不同样本间值的差异较小, 此外随着模型的训练, 变化很大。为了使得其向 接近, 我们首先对其进行平方, 以扩大样本间的值差异, 接着我们提出了一种平滑方法来控制其在不同训练阶段的相对稳定, 并获得soft target label, 用于计算NKD的第一部分target损失:

针对第二部分soft non-target labels, 其组成可以分为不同类别的顺序分布以及值的分布。首先针对顺序, 我们提出对CNN模型的第三个stage输出或者ViT模型的中间层token进行分类, 用一个小的权重对这个分类进行弱的监督, 得到weak logit , 再将 归一化后相加, 得到的顺序作为最终non-target label的顺序 :

对于值的分布, 我们采用了Zipf's LS 的做法, 并利用 进行排序, 获得soft non-target labels, 用于计算NKD的第二部分non-target损失:

三、实验

我们首先在CIFAR-100和ImageNet上对NKD进行了验证,学生更好地利用了老师的软标签soft labels,获得了更好的表现。

对于自蒸馏,我们也在CIFAR-100和ImageNet上对USKD进行了验证,并测试了自蒸馏所需要的额外训练时间,模型在很少的时间消耗下便获得了可观的提升。

我们的NKD和USKD同时适用于CNN模型与ViT模型,因此我们还在更多模型上进行了验证。

四、代码部分

我们已将代码开源:https://github.com/yzd-v/cls_KD

开源代码基于MMClassification,我们也在其中放了对应的模型,并且实现了一些其他文章,比如DKD,MGD,SRRL,WSLD,ViTKD,欢迎大家试用。

此外,我们对MMClassification的0.x大版本与1.x大版本进行了蒸馏环境的适配,供大家交流学习

公众号后台回复“极市直播”获取100+期极市技术直播回放+PPT

极市干货

极视角动态2023GCVC全球人工智能视觉产业与技术生态伙伴大会在青岛圆满落幕!极视角助力构建城市大脑中枢,芜湖市湾沚区智慧城市运行管理中心上线!
数据集:面部表情识别相关开源数据集资源汇总打架识别相关开源数据集资源汇总(附下载链接)口罩识别检测开源数据集汇总
经典解读:多模态大模型超详细解读专栏

点击阅读原文进入CV社区

收获更多技术干货

【声明】内容源于网络
0
0
极市平台
为计算机视觉开发者提供全流程算法开发训练平台,以及大咖技术分享、社区交流、竞赛实践等丰富的内容与服务。
内容 8155
粉丝 0
极市平台 为计算机视觉开发者提供全流程算法开发训练平台,以及大咖技术分享、社区交流、竞赛实践等丰富的内容与服务。
总阅读3.2k
粉丝0
内容8.2k