PyTorch implementation for calculating the mean and standard deviation of an image dataset

1. Implementation process

When preprocessing data with PyTorch, the torchvision.transforms.Normalize(mean, std) method is usually used for normalization, where the parameters mean and std are the per-channel mean and standard deviation of the image dataset, respectively.
First, the mathematical definitions of mean and std are given as follows:
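For context, Normalize is typically composed with ToTensor in a transform pipeline, for example as in the sketch below (the mean/std values here are placeholders; the actual values for CIFAR-10 and MNIST are computed later in this post):

from torchvision import transforms

# A minimal sketch of how the computed statistics are used for normalization.
# The mean/std values here are placeholders, not the values computed below.
train_transform = transforms.Compose([
    transforms.ToTensor(),                      # convert a PIL image to a [0, 1] tensor
    transforms.Normalize(mean=[0.5, 0.5, 0.5],  # per-channel mean (placeholder)
                         std=[0.5, 0.5, 0.5]),  # per-channel std (placeholder)
])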
Suppose there is a data set $X_i,\ i\in\{1,2,\cdots,n\}$. The mean of this data set is:

$$mean=\frac{\displaystyle\sum_{i=1}^{n}X_i}{n}\tag{1}$$

The mean is usually denoted by $\overline{X}$.

The standard deviation of this data set is:

$$std=\sqrt{\frac{\displaystyle\sum_{i=1}^{n}\left(X_i-\overline{X}\right)^2}{n}}=\sqrt{\frac{\displaystyle\sum_{i=1}^{n}\left(X_i^2-2X_i\overline{X}+\overline{X}^2\right)}{n}}=\sqrt{\frac{\left(\displaystyle\sum_{i=1}^{n}X_i^2\right)-n\overline{X}^2}{n}}=\sqrt{\frac{\displaystyle\sum_{i=1}^{n}X_i^2}{n}-\overline{X}^2}\tag{2}$$

The last form of equation (2) is what the code relies on: only $\sum X_i$ (via per-batch means) and $\sum X_i^2$ need to be accumulated, so the statistics can be computed in a single pass over the data. The function code for calculating the mean and standard deviation of each channel of the image dataset is given below:

import torch
from torchvision import transforms,datasets
from torch.utils.data import DataLoader

batch_size = 64

# Training set (using the CIFAR-10 dataset as an example)
train_dataset = datasets.CIFAR10(root='G:/datasets/cifar10',train=True,download=False,transform=transforms.ToTensor())
train_loader = DataLoader(train_dataset,shuffle=True,batch_size=batch_size)

def get_mean_std_value(loader):
    '''
    Compute the per-channel mean and standard deviation of a dataset.
    :param loader: DataLoader that yields (data, label) batches
    :return: (mean, std) tensors, one value per channel
    '''
    data_sum,data_squared_sum,num_batches = 0,0,0

    for data,_ in loader:
        # data: [batch_size,channels,height,width]
        # Accumulate the per-batch mean over dims 0, 2, 3 (batch, height, width); dim=1 (channels) is kept
        data_sum += torch.mean(data,dim=[0,2,3])    # shape: [channels]
        # Accumulate the per-batch mean of the squared values over dims 0, 2, 3
        data_squared_sum += torch.mean(data**2,dim=[0,2,3])  # shape: [channels]
        # Count the number of batches
        num_batches += 1
    # Per-channel mean
    mean = data_sum/num_batches
    # Per-channel standard deviation, using formula (2): std = sqrt(E[X^2] - E[X]^2)
    std = (data_squared_sum/num_batches - mean**2)**0.5
    return mean,std

mean,std = get_mean_std_value(train_loader)
print('mean = {},std = {}'.format(mean,std))
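As a quick sanity check of formula (2), the identity $std=\sqrt{E[X^2]-E[X]^2}$ can be compared against torch.std with unbiased=False (population standard deviation) on a random tensor; this snippet is a sketch added for illustration, not part of the original code:

# Sanity check for formula (2): sqrt(E[X^2] - E[X]^2) equals the population std.
x = torch.rand(1000)
std_formula = (torch.mean(x**2) - torch.mean(x)**2)**0.5
std_builtin = torch.std(x, unbiased=False)     # population standard deviation
print(std_formula.item(), std_builtin.item())  # agree up to floating-point error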

The mean and standard deviation of the CIFAR-10 training set are:

mean = tensor([0.4914, 0.4821, 0.4465]),std = tensor([0.2470, 0.2435, 0.2616])

The mean and standard deviation of the MNIST training set are:

mean = tensor([0.1307]),std = tensor([0.3081])
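The MNIST statistics above can be reproduced by swapping the dataset passed to the loader (a minimal sketch; the root path is an assumed example location, not from the original post):

# Reuse get_mean_std_value with MNIST instead of CIFAR-10.
# The root path below is an example location, adjust it to your local setup.
mnist_dataset = datasets.MNIST(root='G:/datasets/mnist',train=True,download=True,transform=transforms.ToTensor())
mnist_loader = DataLoader(mnist_dataset,shuffle=True,batch_size=batch_size)

mean,std = get_mean_std_value(mnist_loader)
print('mean = {},std = {}'.format(mean,std))  # MNIST images have a single channel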

