pytorch 函数

torch.nn.BatchNorm2d 函数

什么是batch?   

 batch是整个训练集中的一部分,由于训练集往往过大不能一次性全部输入到网络中,所以需要分批次地输送所以每一批就是一个batch(批)

什么是Normalization? 

Normalization翻译是归一化,归一化的引入是为了减少internal covariatie shift现象,其具体表现是在训 练深层网络的过程中,前面层参数的变化会影响后面每层的参数分布,

导致了训练时只能选用较低的学 习速率以及小心谨慎的参数初始化。而Batch Normalization层(BN)的引入允许我们使用更高的学习率以及不用太担心参数初始化的问题

Batch Normalization的具体过程:

对于输入的一个batch数据,先计算出每个维度的平均值和方差,对于一个2*2的灰度图像来说,那么就要计算一个batch内每个像素通道的平均值和方差(共计算四次,得到四个平均值和方差)。

然后通过以下公式得到归一化之后的batch

注意:在测试阶段计算平均值和方差有两种模式:

第一种:通过训练阶段大量batch计算得到的平均值和方差的统计值来代替测试阶段的均值和方差

第二种:通过跟随测试阶段batch的平均值和方差来对第一种方法得到的均值和方差来进行修改

running_mean = momentum * running_mean + (1 - momentum) * train_mean

running_var = momentum * running_var + (1 - momentum) * train_var

其中momoentum为权重,train_mean是训练过程中所有batch的mean的统计量,running_mean是测试batch的简单平均

torch.nn.BatchNorm2d函数

形式:torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

返回一个shape与num_features相同的tensor

其中:

1.num_features为输入batch中图像的channle数(按每一个channle来做归一化)

2.eps是一个稳定系数

3.momentum为running_mean和running_var的权重

4.affine代表着是否可学习,True代表通过学习而来,False代表是固定值

5.track_running_stats代表测试阶段使用第一种还是第二种均值测试方法,True代表第二种,False代表第一种

 

 

猜你喜欢

转载自www.cnblogs.com/tingtin/p/12425849.html