
极市导读
NN乘个常数就能在回归任务上大分了?来看看怎么做到。 >>加入极市CV技术交流群,走在计算机视觉的最前沿
最近在一个预测1-5的回归任务数据集上发现,如果给神经网络训练时候用sigmoid+乘一个常数,替换掉原来的直接预测结果,在本地交叉验证和测试集都能提分很多。
具体的代码大概长这样,中间就是那行代码,相当于模型输出乘个6。
x = layers.Dense(6, activation="sigmoid")(weighted_layer_pool)#这里是sigmoid
output = layers.Rescaling(scale=6.0)(x) #关键
model = tf.keras.Model(inputs=[input_ids, attention_masks],outputs=output)
这个东西的作用有多大呢,大概就是单模型和模型融合的差距。
对比原来的代码是:
x = layers.Dense(6, activation="linear")(weighted_layer_pool) # 这里是linear
model = tf.keras.Model(inputs=[input_ids, attention_masks],outputs=output)
你肯定想说,胡说八道。
当然不是所有场景都有用,我们慢慢看来。
这个数据集就是正在进行的比赛feedback的第三届,属于ai4education领域的文本数据,具体任务是给学生作文作文打分,分1-5,为刻度,评价指标用5项打分维度MSE的平均,一个典型的回归任务。
链接:https://www.kaggle.com/competitions/feedback-prize-english-language-learning/data
这个trick也不是第一次见了,不过他属于包大人武器箱里压箱底的那个,因为适用范围比较小,适合有比较多0-1这个范围之外的回归任务,尤其是打车平台ETA任务,送餐ETA,销量预测等。
具体来说,就是把回归值变换到0-1之前,然后反变换回去。
不够优雅。
在一些时候我们其实可以写进神经网络里。
回到题目,你就明白为什么NN乘个常数就能在回归任务上大分了。
我们先分析下在NN里乘个常数会发生什么。
我们假设参数初始化都是均值为0方差为1的正态分布。以最简单的NN为例子
y=sigmiod(kx+b)
乘常数m,则有
y‘=sigmoid(kx+b)*m
我们对x求导,假设原来的梯度是g,乘m后的梯度就是m*g,
现象一,就是原来梯度的放缩,相当于学习率乘个m
另外,别忘了我们的初始化,我们初始化结果sigmoid出来的分布,大家可以算一下,大概就是集中在0.5附近,这样经过放缩后的m差不多就是0.5m附近。
如果原来是一个1-5的回归任务,这样初始值就会落到3,对于原来的任务,大部分人都是3分。
现象二,对于回归任务来说,这就是一个比较好的初始解。
假设我们不这么做的话,直接用NN拟合1-5的分数分布,然后发现,你要强行把NN从y=0值附近拉过去,开始计算的残差就很大,模型只好一个劲猛学亚马逊。
这两者的区别明白了吧,一个是在解空间不错的位置加速学习,一个是比较差的起点墨迹。
这算是把人的先验加到了神经网络。更别提很多任务都让你回归到1000+了。
当然实际应用的时候,注意学习率调小一点。
除了简单的没有bias项的乘数,我们还可以搞个kx+b,log,sqrt等。
x = layers.Dense(6, activation="sigmoid")(weighted_layer_pool)
output = layers.Rescaling(scale=4.0, offset=1)(x) #关键
model = tf.keras.Model(inputs=[input_ids, attention_masks],outputs=output)
有时候加进NN里不一定适用,比如值域在0-10000的时候,你肯定不想让学习率变5000倍吧。
有没有既要又要的办法?
可以直接先对目标变化,然后配对反变换,测试效果和目标预处理变换和目标反变换相差不大。这样这个Trick,也可以用在xgb和lgb这些树模型里了。
另外,一个类似的简单面试题,sum pool和avg pool在神经网络中有啥区别?
评论区给答案吧。

公众号后台回复“速查表”获取
21张速查表(神经网络、线性代数、可视化等)打包下载~
技术干货:超简单正则表达式入门教程|22 款神经网络设计和可视化的工具大汇总
极视角动态:芜湖市湾沚区联手极视角打造核酸检测便民服务系统上线!|青岛市委常委、组织部部长于玉一行莅临极视角调研
# CV技术社群邀请函 #
备注:姓名-学校/公司-研究方向-城市(如:小极-北大-目标检测-深圳)
即可申请加入极市目标检测/图像分割/工业检测/人脸/医学影像/3D/SLAM/自动驾驶/超分辨率/姿态估计/ReID/GAN/图像增强/OCR/视频理解等技术交流群
极市&深大CV技术交流群已创建,欢迎深大校友加入,在群内自由交流学术心得,分享学术讯息,共建良好的技术交流氛围。
点击阅读原文进入CV社区
获取更多技术干货

