大数跨境
0
0

对比学习损失与交叉熵损失的联系,以及温度系数的作用

对比学习损失与交叉熵损失的联系,以及温度系数的作用 极市平台
2022-10-04
0
↑ 点击蓝字 关注极市平台

作者丨Youngshell@知乎(已授权)
来源丨https://zhuanlan.zhihu.com/p/506544456
编辑丨极市平台

极市导读

 

详解对比学习的损失函数InfoNCE loss和cross entropy loss以及温度系数。>>加入极市CV技术交流群,走在计算机视觉的最前沿

导读

在文章 《对比学习(Contrastive Learning),必知必会》《CIKM2021 当推荐系统遇上对比学习,谷歌SSL算法精读》 中,我们都提到过两个思考:

(1)对比学习常用的损失函数InfoNCE loss和cross entropy loss是否有联系?

(2)对比损失InfoNCE loss中有一个温度系数,其作用是什么?温度系数的设置对效果如何产生影响?

个人认为,这两个问题可以作为对比学习相关项目面试的考点,本文我们就一起盘一盘这两个问题。

1. InfoNCE loss公式

对比学习损失函数有多种,其中比较常用的一种是InfoNCE loss,InfoNCE loss其实跟交叉熵损失有着千丝万缕的关系,下面我们借用恺明大佬在他的论文MoCo里定义的InfoNCE loss公式来说明。论文MoCo提出,我们可以把对比学习看成是一个字典查询的任务,即训练一个编码器从而去做字典查询的任务。假设已经有一个编码好的query (一个特征), 以及一系列编码好的样本 , 那么 可以看作是字典里的key。假设字典里只有一个 (称 为 positive) 是跟 是匹配的,那么 +就互为正样本对, 其余的 的负样本。一旦定义 好了正负样本对, 就需要一个对比学习的损失函数来指导模型来进行学习。这个损失函数需要满足 这些要求, 即当query 和唯一的正样本 相似, 并且和其他所有负样本key都不相似的时候, 这 个loss的值应该比较低。反之, 如果 不相似, 或者 和其他负样本的key相似了, 那么loss就 应该大, 从而惩罚模型, 促使模型进行参数更新。

MoCo采用的对比学习损失函数就是InfoNCE loss,以此来训练模型,公式如下:

2. InfoNCE loss和交叉熵损失有什么关系?

我们先从softmax说起,下面是softmax公式:

交叉熵损失函数如下:

在有监督学习下, ground truth是一个one-hot向量, softmax的结果 取- , 再与ground truth相乘之后, 即得到如下交叉熵损失:

上式中的 在有监督学习里指的是这个数据集一共有多少类别, 比如CV的ImageNet数据集有 1000 类, k就是1000。

对于对比学习来说,理论上也是可以用上式去计算loss,但是实际上是行不通的。为什么呢?

还是拿CV领域的ImageNet数据集来举例,该数据集一共有128万张图片,我们使用数据增强手段(例如,随机裁剪、随机颜色失真、随机高斯模糊)来产生对比学习正样本对,每张图片就是单独一类,那k就是128万类,而不是1000类了,有多少张图就有多少类。但是softmax操作在如此多类别上进行计算是非常耗时的,再加上有指数运算的操作,当向量的维度是几百万的时候,计算复杂度是相当高的。所以对比学习用上式去计算loss是行不通的。

怎么办呢?NCE loss可以解决这个问题。

NCE(noise contrastive estimation)核心思想是将多分类问题转化成二分类问题,一个类是数据类别 data sample,另一个类是噪声类别 noisy sample,通过学习数据样本和噪声样本之间的区别,将数据样本去和噪声样本做对比,也就是“噪声对比(noise contrastive)”,从而发现数据中的一些特性。但是,如果把整个数据集剩下的数据都当作负样本(即噪声样本),虽然解决了类别多的问题,计算复杂度还是没有降下来,解决办法就是做负样本采样来计算loss,这就是estimation的含义,也就是说它只是估计和近似。一般来说,负样本选取的越多,就越接近整个数据集,效果自然会更好。

NCE loss常用在NLP模型中,公式如下:

上述公式细节详见:NCE loss(https://arxiv.org/pdf/1410.8251.pdf)

有了NCE loss,为什么还要用Info NCE loss呢?

Info NCE loss是NCE的一个简单变体,它认为如果你只把问题看作是一个二分类,只有数据样本和噪声样本的话,可能对模型学习不友好,因为很多噪声样本可能本就不是一个类,因此还是把它看成一个多分类问题比较合理(但这里的多分类kk指代的是负采样之后负样本的数量,下面会解释)。于是就有了InfoNCE loss,公式如下:

上式中, 是模型出来的logits, 相当于上文 oftmax公式中的 是一个温度超参数, 是个标 量, 假设我们忽略 , 那么infoNCE loss其实就是cross entropy loss。唯一的区别是, 在cross entropy loss里, 指代的是数据集里类别的数量, 而在对比学习InfoNCE loss里, 这个k指的是负样本的数量。上式分母中的sum是在 1 个正样本和 个负样本上做的, 从0到 , 所以共 个样本, 也就是字典里所有的key。恺明大佬在MoCo里提到, InfoNCE loss其实就是一个cross entropy loss, 做的是一个 类的分类任务, 目的就是想把 这个图片分到 这个类。

另外,我们看下图中MoCo的伪代码,MoCo这个loss的实现就是基于cross entropy loss。

3. 温度系数的作用

温度系数 虽然只是一个超参数, 但它的设置是非常讲究的, 直接影响了模型的效果。上式Info NCE loss中的 相当于是logits, 温度系数可以用来控制logits的分布形状。对于既定的logits分 布的形状, 当 值变大, 则 就变小, 则会使得原来logits分布里的数值都变小, 且经过指数运算之后, 就变得更小了, 导致原来的logits分布变得更平滑。相反, 如果 取得值小, 就 变大, 原来的logits分布里的数值就相应的变大, 经过指数运算之后, 就变得更大, 使得这个分布变得更集中, 更peak。

如果温度系数设的越大,logits分布变得越平滑,那么对比损失会对所有的负样本一视同仁,导致模型学习没有轻重。如果温度系数设的过小,则模型会越关注特别困难的负样本,但其实那些负样本很可能是潜在的正样本,这样会导致模型很难收敛或者泛化能力差。

总之,温度系数的作用就是它控制了模型对负样本的区分度。

参考

[1]Momentum Contrast for Unsupervised Visual Representation Learning.

[2]https://www.bilibili.com/video/BV1C3411s7t9 (墙裂推荐!bryanyzhu大佬出品~)

极市干货
算法竞赛:国际赛事证书,220G数据集开放下载!ACCV2022国际细粒度图像分析挑战赛开赛!
技术综述BEV 学术界和工业界方案、优化方法与tricks综述PyTorch下的可视化工具(网络结构/训练过程可视化)
极视角动态:极视角与华为联合发布基于昇腾AI的「AICE赋能行业解决方案」算法误报怎么办?自训练工具使得算法迭代效率提升50%!

CV技术社群邀请函 #



△长按添加极市小助手
添加极市小助手微信(ID : cvmart2)

备注:姓名-学校/公司-研究方向-城市(如:小极-北大-目标检测-深圳)


即可申请加入极市目标检测/图像分割/工业检测/人脸/医学影像/3D/SLAM/自动驾驶/超分辨率/姿态估计/ReID/GAN/图像增强/OCR/视频理解等技术交流群


极市&深大CV技术交流群已创建,欢迎深大校友加入,在群内自由交流学术心得,分享学术讯息,共建良好的技术交流氛围。

点击阅读原文进入CV社区

收获更多技术干货

【声明】内容源于网络
0
0
极市平台
为计算机视觉开发者提供全流程算法开发训练平台,以及大咖技术分享、社区交流、竞赛实践等丰富的内容与服务。
内容 8155
粉丝 0
极市平台 为计算机视觉开发者提供全流程算法开发训练平台,以及大咖技术分享、社区交流、竞赛实践等丰富的内容与服务。
总阅读13.3k
粉丝0
内容8.2k