Using torch.mean() to calculate the mean and standard deviation of an image dataset

This article follows the methods of YouTube creator Aladdin Persson and Kaggle user Afia Ibnath.

torch.mean()

PyTorch official documentation for torch.mean (screenshot not reproduced).

Parameters:

input: the input tensor
dim: the dimension, or list of dimensions, to reduce

Example (the original screenshots of the code and its output are not reproduced; a small sketch follows below):

From the results it can be seen that when dim is omitted, torch.mean() returns the average over all elements. With dim = 0 the row dimension is reduced, giving one mean per column; with dim = 1 the column dimension is reduced, giving one mean per row.
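A minimal sketch with a made-up 2x3 tensor (the values are mine, not from the original screenshot) to illustrate the three cases:

import torch

x = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.]])

print(torch.mean(x))         # tensor(3.5000) -> mean of all elements
print(torch.mean(x, dim=0))  # tensor([2.5000, 3.5000, 4.5000]) -> one mean per column
print(torch.mean(x, dim=1))  # tensor([2.0000, 5.0000]) -> one mean per row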

Calculate the mean and standard deviation: the helper below accumulates the per-channel means of X and X**2 over the whole loader, then applies var[X] = E[X**2] - E[X]**2 to recover the standard deviation.

import torch
from tqdm import tqdm

def get_mean_std(loader):
    # var[X] = E[X**2] - E[X]**2, where var[] is the variance and E[] the expectation (mean)
    channels_sum, channels_sqrd_sum, num_batches = 0, 0, 0
    for data, _ in tqdm(loader):
        # data is expected to have shape (channel, H, W); see the notice below
        channels_sum += torch.mean(data, dim=[1, 2])
        channels_sqrd_sum += torch.mean(data ** 2, dim=[1, 2])
        num_batches += 1

    mean = channels_sum / num_batches
    std = (channels_sqrd_sum / num_batches - mean ** 2) ** 0.5

    return mean, std, num_batches



This function can be used to calculate the per-channel mean and standard deviation of an RGB three-channel dataset.
It is plug and play and very convenient.
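A possible way to call it, sketched under assumptions that are not from the original article: the folder name "training_set" and the use of torchvision's ImageFolder are mine. Setting batch_size=None makes the DataLoader yield single (channel, H, W) tensors, which matches dim = [1, 2] inside the function.

import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms

# Hypothetical setup: the folder path "training_set" is an assumption.
train_set = torchvision.datasets.ImageFolder("training_set", transform=transforms.ToTensor())
loader = DataLoader(train_set, batch_size=None)  # yields individual (channel, H, W) tensors

mean, std, num_batches = get_mean_std(loader)
print(mean, std)  # per-channel mean and standard deviation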

Running results (screenshot not reproduced): the per-channel mean and standard deviation computed for the cat-and-dog training set.
Training data set:
https://www.kaggle.com/tongpython/cat-and-dog/code

Notice

1. The function expects tensor inputs, so each image must be converted to a tensor of shape (channel, H, W) in advance.
2. If the input images instead have shape (batchsize, channel, H, W), the dim argument must be changed to [0, 2, 3]; see the sketch after this list.
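A minimal sketch of that batched variant (the name get_mean_std_batched is mine, not from the original article); note that averaging per-batch means is exact only when every batch holds the same number of images.

def get_mean_std_batched(loader):
    # Same idea as get_mean_std, but for loaders that yield (batchsize, channel, H, W)
    # tensors; only the reduced dimensions change.
    channels_sum, channels_sqrd_sum, num_batches = 0, 0, 0
    for data, _ in tqdm(loader):
        channels_sum += torch.mean(data, dim=[0, 2, 3])
        channels_sqrd_sum += torch.mean(data ** 2, dim=[0, 2, 3])
        num_batches += 1

    mean = channels_sum / num_batches
    std = (channels_sqrd_sum / num_batches - mean ** 2) ** 0.5
    return mean, std, num_batches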

This is the first article I have published; if anything is wrong, I welcome criticism and corrections. Please do not reprint without permission!
