大数跨境
0
0

标签平滑Label Smoothing技巧总结

标签平滑Label Smoothing技巧总结 AI算法之道
2022-03-08
0
导读:本文对Label Smoothing技术进行了简单的介绍,通过简单的例子来增加大家的直观认识,最后分享了该技巧的代码实现。





01


引言


Label Smoothing 又被称之为标签平滑,常常被用在分类网络中来作为防止过拟合的一种手段,整体方案简单易用,在小数据集上可以取得非常好的效果。

Label Smoothing 作为一种简单的训练trick,可以通过很少的代价(只需要修改target的编码方式),即可获得准确率的提升,本文就其原理和具体实现进行介绍,希望可以帮助大家理解其背后的具体原理。



02


初识


我们首先来看Label Smoothing的公式,在介绍之前我们先来观察一下传统的 one-hot encoding的公式,如下:

而Label Smoothing 导入了一个factor机制,公式改变如下:

只看公式,多少有些难懂,好嘛!我们来举个栗子瞧瞧啦。。。

不妨假设我们今天有四个类别,分别为dog,cat,bird,turtle,我们对其进行编码,即:
dog = 0 , cat = 1 , bird = 2 , turtle = 3
  •  我们采用one-hot对其进行编码,结果如下:

dog    = [ 1 , 0 , 0 , 0 ]cat    = [ 0 , 1,  0 , 0 ]bird   = [ 0 , 0 , 1 , 0 ]turtle = [ 0,  0 , 0 , 1 ]
  • 采用Label smoothing对其编码,引入一个factor来将其几率分配给其他类别 ,这里假设factor = 0.1,则生成的标签如下:

dog    = [ 0.9 , 0.03 , 0.03 , 0.03 ]cat    = [ 0.03 , 0.9 , 0.03 , 0.03 ]bird   = [ 0.03 , 0.03 , 0.9 , 0.03 ]turtle = [ 0.03 , 0.03 , 0.03 , 0.9 ]




03


深入


有了上述直观的理解,想必大家对Label Smoothing有了简单的认识,接着我们来思考这样的改变会对损失函数带来什么样的影响。

为此我们先来看一下分类任务中最常见的cross entropy损失函数,如下:

接着我们使用上述四个类别,来看看正确分类时Loss的计算,如下:

观察上图可以看出,如果整体分类呈现正常梯度下降的话,使用Label Smoothing相比不使用的loss下降相对比较小。

那反过来,如果网络越学预测效果越差呢?

通过上图可以看出,就算训练阶段预测错误时使用Label Smoothing的loss也相比之前惩罚的更小(扣得更少)。



04



实现



我们来对Label Smoothing技术,作如下总结:
  • 使用了Label Smoothing损失函数后,在训练阶段预测正确时 loss 不会下降得太快,预测错误的時候 loss 不会惩罚得太多,使其不容易陷入局部最优点,这在一定程度可以抑制网络过拟合的现象。

  • 对于分类类别比较接近的场景,网络的预测不会过于绝对,在引入Label Smoothing技巧后,通过分配这些少数的几率也可以使得神经网络在训练的时候不这么绝对。

接着,我们来用Python对其实现,代码如下:
def label_smoothing(labels, factor=0.1):    num_labels = labels.get_shape().as_list()[-1]    labels = ((1-factor) * labels) + (factor/ num_labels)    return labels




05



经验分享



在实际调参的一些经验分享如下:
  • 不管是在object detection的分类网络或者是多分类网络导入label smoothing皆有不错的效果,基本上算轻松又容易提升准确度的做法

  • 当数据量足够多的时候,Label smoothing这个技巧很容易使网络变得欠拟和。

  • factor通常设置为0.1,之前做对比实验试过使用0.2,0.3等参数,会发现皆无较好的效果,反而使网络变得难以收敛。

  • 可以利用label smoothing的特性来做点微小的改动,比如遇上相似类型的事物时,可以将factor分配给相似的类别,而不是全部类别,这通常会有不错的效果。




06


总结


本文对Label Smoothing技术进行了简单的介绍,通过简单的例子来增加大家的直观认识,最后分享了该技巧的代码实现。


您学废了吗?






点击上方小卡片关注我




万水千山总关情,点个在看行不行。

【声明】内容源于网络
0
0
AI算法之道
一个专注于深度学习、计算机视觉和自动驾驶感知算法的公众号,涵盖视觉CV、神经网络、模式识别等方面,包括相应的硬件和软件配置,以及开源项目等。
内容 573
粉丝 0
AI算法之道 一个专注于深度学习、计算机视觉和自动驾驶感知算法的公众号,涵盖视觉CV、神经网络、模式识别等方面,包括相应的硬件和软件配置,以及开源项目等。
总阅读148
粉丝0
内容573