
01
引言
本文重点介绍BatchNorm的定义和相关特性,并介绍了其详细实现和具体应用。希望可以帮助大家加深对其理解。
嗯嗯,闲话少说,我们直接开始吧!
02
什么是BatchNorm?
-
易于训练:由于网络权重的分布随这一层的变化小得多,因此我们可以使用更高的学习率。我们在训练中收敛的方向没有那么不稳定,这样我们就可以更快地朝着loss收敛的方向前进。 -
提升正则化:尽管网络在每个epoch都会遇到相同的训练样本,但每个小批量的归一化是不同的,因此每次都会稍微改变其值。 -
提升精度:可能是由于前面两点的结合,论文提到他们获得了比当时最先进的结果更好的准确性。
03
BatchNorm是如何工作的?
BatchNorm所做的是确保接收到的输入具有平均值0和标准偏差1。
本文中介绍的算法如下:

下面是我自己用pytorch进行的实现:
import numpy as npimport torchfrom torch import nnfrom torch.nn import Parameterclass BatchNorm(nn.Module):def __init__(self, num_features, eps=1e-5, momentum=0.1):super().__init__()self.gamma = Parameter(torch.Tensor(num_features))self.beta = Parameter(torch.Tensor(num_features))self.register_buffer("moving_avg", torch.zeros(num_features))self.register_buffer("moving_var", torch.ones(num_features))self.register_buffer("eps", torch.tensor(eps))self.register_buffer("momentum", torch.tensor(momentum))self._reset()def _reset(self):self.gamma.data.fill_(1)self.beta.data.fill_(0)def forward(self, x):if self.training:mean = x.mean(dim=0)var = x.var(dim=0)self.moving_avg = self.moving_avg * momentum + mean * (1 - momentum)self.moving_var = self.moving_var * momentum + var * (1 - momentum)else:mean = self.moving_avgvar = self.moving_varx_norm = (x - mean) / (torch.sqrt(var + self.eps))return x_norm * self.gamma + self.beta
这里对其进行补充说明如下:
我们在训练和推理过程中BatchNorm有不同的行为。在训练中,我们记录均值和方差的指数移动平均值,以供以后在推理时使用。其原因是,在训练期间处理批次时,我们可以获得输入随时间变化的均值和方差的更好估计,然后将其用于推理。在推理过程中使用输入批次的平均值和方差将不太准确,因为其大小可能比训练中使用的小得多,大数定律在这里发挥了作用。
04
什么时候使用Batchnorm ?
这似乎总是有帮助的,所以没有理由不使用它。通常它出现在全连接层/卷积层和激活函数之间。但也有人认为,最好把它放在激活层之后。我找不到任何关于激活函数之后使用它的论文,所以最安全的选择是按照每个人的做法,在激活函数前使用它。
05
总结
-
我们知道,一个已经训练的网络包含用于训练它的数据集的移动平均值和方差,这可能是一个问题。在迁移学习期间,我们通常会冻结大部分层,如果不小心,BatchNorm层也会冻结,这意味着应用的移动平均值属于原始数据集,而不是新数据集。解冻BatchNorm层是一个好主意,将允许网络重新计算自己数据集上的移动平均值和方差。
点击上方小卡片关注我
万水千山总关情,点个在看行不行。

