pytorch标准化(Batch Normalize)中计算数据集的均值和方差

1,浅谈深度学习训练中数据规范化(Normalization)的重要性

为什么要Normalization? 简要来讲,在pytorch中,有些模型是通过规范化后的数据进行训练的,所以我们在使用这些预训练好的模型的时候,要注意在将自己的数据投入模型中之前要首先对数据进行规范化。

具体请参考 浅谈深度学习训练中数据规范化(Normalization)的重要性 - Oldpan的个人博客这篇文章讲得不错,值得好好看看。

简要的讲,BN层存在的意义(原因)在于:

随着网络的深度增加,每层特征值分布会逐渐的向激活函数的输出区间的上下两端(激活函数饱和区间)靠近,这样继续下去就会导致梯度消失。BN就是通过方法将该层特征值分布重新拉回标准正态分布,特征值将落在激活函数对于输入较为敏感的区间,输入的小变化可导致损失函数较大的变化,使得梯度变大,避免梯度消失,同时也可加快收敛。

其作用可以归纳为以下几点:

1,减少梯度消失,加快了收敛过程。

2,起到类似dropout一样的正则化能力,一定程度上防止过拟合。

3,放宽了一定的调参要求(放宽对参数初始值的要求)。

4,可以替代LRN。

2,代码实现

网上有计算数据集的均值和方差的代码,但是运行起来结果不正确,好多个博客都是人云亦云,同一份代码来回抄袭,实用性差。本人实现并测试了计算数据集的均值和方差的代码,结果是正确的。

上货!

import numpy as np
import cv2
import random
import os

# 挑选多少图片进行计算,在这里不挑选,而是计算全部的。
# 如果你有1万张,可以让CNum = 2000,随机选2000张计算。
CNum = 0   
#遍历某目录下所有的图片文件
root_path = "C:/Users/54010/Desktop/tmp"
filelist = os.listdir(root_path)  # 列出文件夹下所有的目录与文件
imgPath_list = []
for file in filelist:
    fpath = root_path + '/' + file
    # 做判断时需要传入完整文件路径
    if (os.path.isfile(fpath) and file.endswith(".png")):
        imgPath_list.append(fpath)
        CNum += 1 # 如果随机选取一部分图片计算,就注释掉本句

# shuffle , 随机挑选图片
random.shuffle(imgPath_list)

img_h, img_w = 300, 300 # 数据集中的图像尺寸
imgs = []
means, stdevs = [], []

for i in range(CNum):
    print("i=",i,imgPath_list[i])
    img = cv2.imread(imgPath_list[i])
    tmpImg = img.flatten() #多维的数组降为1维
    print("np.mean",np.mean(tmpImg))
    print("np.std", np.std(tmpImg))
    # img = cv2.resize(img, (img_h, img_w), interpolation=cv2.INTER_AREA)
    # cv2.imshow(str(i), img)
    # cv2.waitKey(0)
    img = img / 255.0 # 转换到【0,1】区间
    imgs.append(img)

imgs_numpy = np.array(imgs)
print("imgs.shape:",imgs_numpy.shape) # ((6, 300, 300, 3)) 按照 N  W  H  C 顺序
for i in range(3):#遍历每一个通道 B  G  R
    pixels = imgs_numpy[:, :, :, i].flatten() # 拉成一行
    means.append(np.mean(pixels))
    stdevs.append(np.std(pixels))

# cv2 读取的图像格式为BGR,PIL/Skimage读取到的都是RGB不用转
means.reverse() # BGR --> RGB
stdevs.reverse()

print("normMean = {}".format(means))
print("normStd = {}".format(stdevs))
print('transforms.Normalize(normMean = {}, normStd = {})'.format(means, stdevs))


print("program over!!")

Guess you like

Origin blog.csdn.net/thequitesunshine007/article/details/120488650