A detailed explanation of the torch.nn.MSELoss function in PyTorch, including an analysis of each parameter.

1. Function introduction

The interface of the MSELoss function in PyTorch is declared as follows; see the official documentation for details.

torch.nn.MSELoss(size_average=None, reduce=None, reduction='mean')

By default, this function computes the mean of the squared differences between corresponding elements of its two inputs. In deep learning, it is commonly used to measure the similarity of two feature maps, or of a prediction and its target.
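Concretely, for inputs x and y with n elements in total, the default loss is (1/n) · Σᵢ (xᵢ − yᵢ)². A minimal sketch (the tensor values here are illustrative):

```python
import torch

# MSELoss with the default reduction='mean' computes mean((input - target)**2).
pred = torch.tensor([0.5, 1.0, 2.0])
true = torch.tensor([1.0, 1.0, 1.0])

mse = torch.nn.MSELoss()(pred, true)
manual = ((pred - true) ** 2).mean()  # (0.25 + 0.0 + 1.0) / 3
print(mse, manual)  # both approximately 0.4167
```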

2. How to use
import torch

# input and target are the two inputs to MSELoss.
input = torch.tensor([0., 0., 0.])
target = torch.tensor([1., 2., 3.])

# MSELoss is used as follows; all parameters are left at their defaults.
loss = torch.nn.MSELoss(size_average=None, reduce=None, reduction='mean')
loss = loss(input, target)

print(loss)

# The mean of the element-wise squared differences between input and target,
# computed by hand below, matches the return value of MSELoss above.
# This confirms that by default MSELoss returns the mean of the element-wise
# squared differences of its two inputs.
print(((1-0)*(1-0) + (2-0)*(2-0) + (3-0)*(3-0)) / 3.)


3. Parameter introduction

If the reduce, size_average, and reduction parameters are all given, the first two take precedence. If both reduce and size_average are None, the return value is determined by reduction alone. If either of the first two is not None, the return value is determined by reduce and size_average, and whichever of them is None is treated as True. Once the values of the three parameters are settled, the result is computed according to the following rules:

  • When reduce=True and size_average=True, the function returns the mean of all sample losses in the batch, and the result is a scalar. Note that for MSELoss the mean is taken over all elements: the element-wise mean is computed for each of the N samples in the batch (assuming a batch size of N), and the mean of those N values is returned. In the words of the official documentation, "the mean operation still operates over all the elements, and divides by n".
  • When reduce=True and size_average=False, the function returns the sum of all sample losses in the batch, and the result is a scalar. Note that for MSELoss the sum is taken over all elements: the element-wise sum is computed for each of the N samples in the batch, and those N values are summed. In the words of the official documentation, "the sum operation still operates over all the elements".
  • When reduce=False, the size_average parameter has no effect: whether it is True or False, the result is the same. In this case the function returns the element-wise loss, a tensor with the same shape as the inputs.
  • The reduction parameter subsumes both reduce and size_average: reduction='none' is equivalent to reduce=False; reduction='sum' is equivalent to reduce=True with size_average=False; and reduction='mean' is equivalent to reduce=True with size_average=True. This is why reduce and size_average are deprecated and will be removed in a future version.
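The equivalences in the last bullet can be checked directly; a small sketch with made-up values:

```python
import torch

x = torch.zeros(4)
y = torch.tensor([1., 2., 3., 4.])

none_loss = torch.nn.MSELoss(reduction='none')(x, y)  # element-wise losses
sum_loss  = torch.nn.MSELoss(reduction='sum')(x, y)   # scalar sum
mean_loss = torch.nn.MSELoss(reduction='mean')(x, y)  # scalar mean

print(none_loss)  # tensor([ 1.,  4.,  9., 16.])
print(sum_loss)   # tensor(30.)
print(mean_loss)  # tensor(7.5000)

# 'sum' and 'mean' are just reductions of the 'none' result.
assert torch.equal(sum_loss, none_loss.sum())
assert torch.equal(mean_loss, none_loss.mean())
```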

In practice, you do not need to think about all of this when using the function. The detailed analysis above is only meant to explain the function systematically for readers who want to dig deeper. To use the function quickly, simply leave reduce and size_average at their default of None and pass only the reduction parameter. For concrete usage examples, see part 4.

4. Examples

1. With reduction='mean', the function returns the mean of all sample losses in the batch.

import torch
input = [[[0.,0.,0.],
          [0.,0.,0.],
          [0.,0.,0.]],

         [[0.,0.,0.],
          [0.,0.,0.],
          [0.,0.,0.]]]
input = torch.tensor(input)

target = [[[1.,2.,3.],
           [4.,5.,6.],
           [7.,8.,9.]],

          [[11.,12.,13.],
           [14.,15.,16.],
           [17.,18.,19.]]]
target = torch.tensor(target)

loss = torch.nn.MSELoss(reduction='mean') # torch.nn.MSELoss() behaves identically, since reduction defaults to 'mean'.
loss = loss(input, target)
print(loss)

# Note: the final division by 2 below reflects the batch size of 2; the divisions by 9
# reflect the 9 elements per sample in the batch.
mean_result = ((1.*1. + 2.*2. + 3.*3. + 4.*4. + 5.*5. + 6.*6. + 7.*7. + 8.*8. + 9.*9.)/9 + (11.*11. + 12.*12. + 13.*13. + 14.*14. + 15.*15. + 16.*16. + 17.*17. + 18.*18. + 19.*19.)/9) / 2
print(mean_result)


2. With reduction='sum', the function returns the sum of all sample losses in the batch.

import torch
input = [[[0.,0.,0.],
          [0.,0.,0.],
          [0.,0.,0.]],

         [[0.,0.,0.],
          [0.,0.,0.],
          [0.,0.,0.]]]
input = torch.tensor(input)

target = [[[1.,2.,3.],
           [4.,5.,6.],
           [7.,8.,9.]],

          [[11.,12.,13.],
           [14.,15.,16.],
           [17.,18.,19.]]]
target = torch.tensor(target)

loss = torch.nn.MSELoss(reduction='sum')
loss = loss(input, target)
print(loss)
sum_result = ((1.*1. + 2.*2. + 3.*3. + 4.*4. + 5.*5. + 6.*6. + 7.*7. + 8.*8. + 9.*9.) + (11.*11. + 12.*12. + 13.*13. + 14.*14. + 15.*15. + 16.*16. + 17.*17. + 18.*18. + 19.*19.))
print(sum_result)


3. With reduction='none', the function returns the element-wise loss: a tensor with the same shape as the inputs.

import torch
input = [[[0.,0.,0.],
          [0.,0.,0.],
          [0.,0.,0.]],

         [[0.,0.,0.],
          [0.,0.,0.],
          [0.,0.,0.]]]
input = torch.tensor(input)

target = [[[1.,2.,3.],
           [4.,5.,6.],
           [7.,8.,9.]],

          [[11.,12.,13.],
           [14.,15.,16.],
           [17.,18.,19.]]]
target = torch.tensor(target)

loss = torch.nn.MSELoss(reduction='none')
loss = loss(input, target)
print(loss)
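The three reductions can also be tied together on this same batch: summing or averaging the reduction='none' result reproduces the values from the two previous examples. A quick check:

```python
import torch

input = torch.zeros(2, 3, 3)
target = torch.tensor([[[ 1.,  2.,  3.], [ 4.,  5.,  6.], [ 7.,  8.,  9.]],
                       [[11., 12., 13.], [14., 15., 16.], [17., 18., 19.]]])

per_element = torch.nn.MSELoss(reduction='none')(input, target)

# Summing the element-wise losses gives the reduction='sum' result,
# and averaging them gives the reduction='mean' result.
assert torch.equal(per_element.sum(), torch.nn.MSELoss(reduction='sum')(input, target))
assert torch.isclose(per_element.mean(), torch.nn.MSELoss(reduction='mean')(input, target))
print(per_element.sum(), per_element.mean())
```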


5. Reference link


Origin: blog.csdn.net/qq_40968179/article/details/128260036