[Loss function: 2] Charbonnier Loss, SSIM Loss (with Pytorch implementation)

Preface

When introducing each loss function below, the following variables are involved. Suppose the network input is $x$, the network output is $\bar{y} = f(x)$ (written y_ in the code), and the ground-truth label of $x$ is $y$. $N$ denotes the number of samples contained in a batch: during training we usually feed the network batch by batch, the loss is computed once per batch, and a parameter update follows.
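To make the batch-wise pattern concrete, here is a minimal training-step sketch; the model, data, and criterion are placeholders of my own, not from the original post:

import torch
import torch.nn as nn

# Placeholder model, loss, and data -- assumptions for illustration only
model = nn.Linear(10, 10)
criterion = nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
train_loader = [(torch.randn(4, 10), torch.randn(4, 10))]  # one dummy batch of N=4

for x, y in train_loader:       # one batch of N samples at a time
    y_bar = model(x)            # network output y_ = f(x)
    loss = criterion(y_bar, y)  # loss computed once per batch
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()            # parameter update after each batch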

1. Charbonnier loss

Reference article link: http://xxx.itp.ac.cn/pdf/1710.01992
The reference article applies it to image super-resolution. For a generic supervised task, Charbonnier Loss can be defined as:

$$L_{\text{Charbonnier}}(y, \bar{y}) = \frac{1}{N}\sum_{i=1}^{N} \sqrt{(y_i - \bar{y}_i)^2 + \epsilon^2}$$

where $\epsilon$ is a small constant (e.g. $10^{-3}$) that keeps the expression differentiable everywhere.

Look at the curve of Charbonnier Loss on the interval (-1, 1): L1 loss has a non-differentiable point at $y - \bar{y} = 0$ (see https://blog.csdn.net/qq_43665602/article/details/127037761), and Charbonnier Loss fixes this defect by introducing the constant $\epsilon$, so the curve is also differentiable where $y - \bar{y}$ is close to 0. Outside this neighborhood the curve approximates L1 loss, which is less sensitive to outliers than L2 loss and avoids magnifying large errors excessively.
(Figure: Charbonnier Loss curve on (-1, 1).)
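The smoothness is easy to verify with autograd: near 0 the Charbonnier gradient $x/\sqrt{x^2+\epsilon^2}$ varies continuously, while the L1 gradient jumps between -1 and 1. This check is my addition, not from the original post:

import torch

eps = 1e-3
xs = torch.tensor([-1e-2, -1e-4, 0.0, 1e-4, 1e-2], requires_grad=True)
torch.sqrt(xs ** 2 + eps ** 2).sum().backward()
print(xs.grad)  # ~[-0.9950, -0.0995, 0.0000, 0.0995, 0.9950] -- smooth through 0

ys = torch.tensor([-1e-2, -1e-4, 0.0, 1e-4, 1e-2], requires_grad=True)
torch.abs(ys).sum().backward()
print(ys.grad)  # [-1., -1., 0., 1., 1.] -- sign(x) jumps at 0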

1) Code implementation

import torch
import torch.nn as nn

# 4. Charbonnier Loss
class CharbonnierLoss(nn.Module):
    def __init__(self, epsilon=1e-3):
        super(CharbonnierLoss, self).__init__()
        self.epsilon2 = epsilon * epsilon

    def forward(self, x):
        # x is the residual y_ - y; element-wise sqrt(x^2 + eps^2), then mean
        value = torch.sqrt(torch.pow(x, 2) + self.epsilon2)
        return torch.mean(value)


criterion = CharbonnierLoss()
loss = criterion(y_ - y)  # y_ : prediction, y : label (the two tensors printed below)
print(loss)
tensor([[6., 6., 6., 8., 3., 6., 6., 7., 0., 5.],
        [6., 9., 7., 1., 5., 5., 6., 2., 7., 0.],
        [0., 4., 4., 6., 9., 1., 1., 4., 6., 0.],
        [8., 5., 8., 1., 7., 5., 9., 1., 4., 7.]])
tensor([[9., 4., 9., 7., 5., 5., 5., 4., 9., 6.],
        [5., 4., 8., 8., 3., 2., 7., 4., 2., 8.],
        [2., 1., 5., 3., 1., 1., 3., 9., 5., 9.],
        [8., 8., 1., 0., 1., 5., 9., 8., 9., 0.]])
-------------------------
tensor(3.2751)
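As a quick sanity check (my addition): for residuals far larger than epsilon, the class above nearly reproduces the L1 loss, confirming the curve analysis:

residual = torch.randn(4, 10) * 5  # residuals with |x| >> epsilon
charb = CharbonnierLoss(epsilon=1e-3)(residual)
l1 = torch.mean(torch.abs(residual))  # plain L1 of the same residual
print(charb, l1)  # nearly identical; the gap is at most on the order of epsilon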

2. SSIM loss

Before formally introducing the SSIM loss, we need to know what SSIM is and how it is calculated.

1. Structural Similarity (SSIM)

"Image Quality Assessment: From Error Visibility to Structural Similarity" proposes the use of structural similarity metrics for image quality assessment tasks, and introduces the origin of the development of SSIM in detail. Below we introduce the entire definition and calculation process of SSIM with the original text:
the total similarity measure is defined as:
insert image description here

Here $l(x, y)$, $c(x, y)$ and $s(x, y)$ denote the luminance comparison, contrast comparison and structure comparison respectively. The three components are relatively independent: changing one does not affect the others. Each definition must satisfy the following three conditions:

  • Symmetry: S(x, y) = S(y, x);
  • Boundedness: S(x, y) <= 1;
  • Unique maximum: S(x, y) = 1 if and only if x = y (for discrete signals, x = y means their corresponding elements are equal, i.e. $x_i = y_i$);

1) Luminance comparison:
It is defined as:

$$l(x, y) = \frac{2\mu_x \mu_y + C_1}{\mu_x^2 + \mu_y^2 + C_1}$$

where $\mu_x$ and $\mu_y$ are defined analogously as the average intensity of each input signal:

$$\mu_x = \frac{1}{N}\sum_{i=1}^{N} x_i$$

$C_1$ is a constant that avoids instability when $\mu_x^2 + \mu_y^2$ is close to 0; it is defined as follows ($C_2$ and $C_3$ are defined similarly):

$$C_1 = (K_1 L)^2$$

Here $K_1 \ll 1$ is a constant and $L$ is the dynamic range of the pixel values (e.g. $L = 255$ for an 8-bit grayscale image). Clearly this definition of $l(x, y)$ satisfies the three conditions above. Moreover, $l(x, y)$ is consistent with Weber's law: the perceptible luminance change $\Delta I$ is proportional to the background luminance $I$, i.e. $\Delta I / I = C$ for a constant $C$. Let $R$ denote the relative magnitude of the luminance change and write the distorted signal as $\mu_y = (1 + R)\mu_x$; assuming $C_1$ is negligible compared with $\mu_x^2$, substitution gives:

$$l(x, y) = \frac{2(1 + R)}{1 + (1 + R)^2}$$
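A one-line numeric check of this reduction (my addition, with $C_1$ set to 0):

mu_x, R = 0.5, 0.2
mu_y = (1 + R) * mu_x
print((2 * mu_x * mu_y) / (mu_x ** 2 + mu_y ** 2))  # 0.98360...
print(2 * (1 + R) / (1 + (1 + R) ** 2))             # 0.98360... -- same value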
2) Contrast comparison:
The contrast is estimated by the standard deviation (the square root of the variance) of the input signal:

$$\sigma_x = \left(\frac{1}{N - 1}\sum_{i=1}^{N}(x_i - \mu_x)^2\right)^{1/2}$$
$c(x, y)$ is then defined as:

$$c(x, y) = \frac{2\sigma_x \sigma_y + C_2}{\sigma_x^2 + \sigma_y^2 + C_2}$$
where $C_2 = (K_2 L)^2$ and $K_2 \ll 1$ is a constant.
3) Structure comparison:
It is defined as:

$$s(x, y) = \frac{\sigma_{xy} + C_3}{\sigma_x \sigma_y + C_3}$$

where $\sigma_x$ and $\sigma_y$ are the standard deviations defined above, and $\sigma_{xy}$ is the covariance:

$$\sigma_{xy} = \frac{1}{N - 1}\sum_{i=1}^{N}(x_i - \mu_x)(y_i - \mu_y)$$
Combining $l(x, y)$, $c(x, y)$ and $s(x, y)$ gives the SSIM formula:

$$\text{SSIM}(x, y) = [l(x, y)]^{\alpha} \cdot [c(x, y)]^{\beta} \cdot [s(x, y)]^{\gamma}$$

The parameters $\alpha > 0$, $\beta > 0$ and $\gamma > 0$ adjust the relative importance of the three components. For simplicity they are all set to 1, and $C_3 = C_2 / 2$, which yields:

$$\text{SSIM}(x, y) = \frac{(2\mu_x \mu_y + C_1)(2\sigma_{xy} + C_2)}{(\mu_x^2 + \mu_y^2 + C_1)(\sigma_x^2 + \sigma_y^2 + C_2)}$$
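Before moving on to the windowed version, the formula can be applied directly over two whole tensors. The sketch below is my addition (it uses the biased 1/N variance for brevity) and assumes pixel values already scaled to [0, L]:

import torch

def ssim_global(x, y, K1=0.01, K2=0.03, L=1.0):
    # Single-window SSIM over the whole tensors, per the formula above
    C1, C2 = (K1 * L) ** 2, (K2 * L) ** 2
    mu_x, mu_y = x.mean(), y.mean()
    sigma_x2 = ((x - mu_x) ** 2).mean()
    sigma_y2 = ((y - mu_y) ** 2).mean()
    sigma_xy = ((x - mu_x) * (y - mu_y)).mean()
    return ((2 * mu_x * mu_y + C1) * (2 * sigma_xy + C2)) / \
           ((mu_x ** 2 + mu_y ** 2 + C1) * (sigma_x2 + sigma_y2 + C2))

x = torch.rand(64, 64)
print(ssim_global(x, x))                   # tensor(1.) -- identical inputs
print(ssim_global(x, torch.rand(64, 64)))  # much lower for unrelated noise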

2. Mean Structural Similarity (Mean SSIM)

The SSIM formula introduced above only measures the structural similarity of a local region. The mean, variance and degree of distortion can differ noticeably across regions of a whole image, so the local formula applied once cannot measure global similarity. The authors' solution is MSSIM (Mean SSIM): split the image into multiple patches, compute the local structural similarity of each patch, and average them as the global measure. The local statistics are then computed with a weighting window:

$$\mu_x = \sum_{i=1}^{N} w_i x_i, \qquad \sigma_x = \left(\sum_{i=1}^{N} w_i (x_i - \mu_x)^2\right)^{1/2}, \qquad \sigma_{xy} = \sum_{i=1}^{N} w_i (x_i - \mu_x)(y_i - \mu_y)$$

where $w_i$ is the weight of each pixel, with $\sum_{i=1}^{N} w_i = 1$. The plain averages used in the local formulas above are the special case $w_i = 1/N$. Finally, assuming the image is divided into $M$ patches, the global measure is:

$$\text{MSSIM}(X, Y) = \frac{1}{M}\sum_{j=1}^{M}\text{SSIM}(x_j, y_j)$$

3. Code implementation

A few points need to be clarified before writing the code:
1) The meaning of "patch"
A convolution is itself a weighted sum: each patch corresponds to one position of the sliding window, and the weights are given by the elements of the convolution kernel.
2) Computing the local mean, standard deviation, and covariance
(1) The local mean can be obtained with a single convolution;
(2) The standard deviation is the square root of the variance, and the variance can be computed as $\mathrm{Var}(x) = E(x^2) - E(x)^2$ (see the sketch after the covariance steps below):

  • First square the input image x element-wise, then convolve it to get $E(x^2)$;
  • Then convolve the input image x to get $E(x)$;
  • Finally compute $E(x^2) - E(x)^2$.
(3) The covariance is obtained by a similar process, via $\mathrm{Cov}(x, y) = E(xy) - E(x)E(y)$:

  • First multiply the input images x and y element-wise, then convolve to get $E(xy)$;
  • Then convolve the input images x and y to obtain $E(x)$ and $E(y)$;
  • Finally compute $E(xy) - E(x)E(y)$.
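The sketch below (my addition) verifies both convolution tricks with a uniform 3x3 window ($w_i = 1/9$), comparing against direct computation on a single patch:

import torch
import torch.nn.functional as F

x = torch.rand(1, 1, 8, 8)
y = torch.rand(1, 1, 8, 8)
w = torch.ones(1, 1, 3, 3) / 9.0  # uniform window, weights sum to 1

E_x = F.conv2d(x, w)                  # local mean E(x)
E_y = F.conv2d(y, w)
var = F.conv2d(x * x, w) - E_x ** 2   # E(x^2) - E(x)^2
cov = F.conv2d(x * y, w) - E_x * E_y  # E(xy) - E(x)E(y)

# Direct computation on the top-left 3x3 patch agrees:
p, q = x[0, 0, :3, :3], y[0, 0, :3, :3]
print(var[0, 0, 0, 0], ((p - p.mean()) ** 2).mean())
print(cov[0, 0, 0, 0], ((p - p.mean()) * (q - q.mean())).mean())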
The following is the final code implementation (pixel values of the input images must be normalized to [0, 1] beforehand, since the constants C1 and C2 below assume L = 1):

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.autograd import Variable
from PIL import Image
from torchvision import transforms
from math import exp
# 5. SSIM loss
# Generate a 1-D Gaussian weight vector and normalize it
def gaussian(window_size,sigma):
    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
    return gauss/torch.sum(gauss)  # normalize


# x = gaussian(3, 1.5)
# print(x)
# x = x.unsqueeze(1)
# print(x.shape)  # torch.Size([3, 1])
# print(x.t().unsqueeze(0).unsqueeze(0).shape)  # torch.Size([1, 1, 1, 3])

# Generate the sliding-window weights: build the Gaussian kernel as the
# outer product of the 1-D Gaussian vector with itself
def create_window(window_size,channel=1):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)  # shape: (window_size, 1)
    # mm: matrix multiplication; t(): transpose -> (1, 1, window_size, window_size)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    # expand: replicate to a larger size, e.g. (3, 1) -> (3, 4) copies the column four times;
    # here (1, 1, window_size, window_size) -> (channel, 1, window_size, window_size)
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window


def _ssim(img1, img2, window, window_size, channel, size_average=True):
    # groups=channel -> depthwise convolution: each channel is filtered by its own Gaussian window
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2

    # Stability constants, assuming inputs normalized to [0, 1]:
    # C1 = (K1*L)^2, C2 = (K2*L)^2 with K1 = 0.01, K2 = 0.03, L = 1
    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)


# Loss module: usable for network training or for simply computing the SSIM value
class SSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            window = create_window(self.window_size, channel)

            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)

            self.window = window
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)


# Plain function for computing SSIM
def ssim(img1, img2, window_size=11, size_average=True):
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)

    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)

    return _ssim(img1, img2, window, window_size, channel, size_average)

4. Test case

Two usages are shown:

  • simply computing the SSIM value between two images;
  • constructing a loss function from SSIM for network training.

Input images: 01_hazy.png (hazy) and 01_GT.png (ground truth).

Reading the data:

# Read the data
haze_path='./01_hazy.png'
gt_path='./01_GT.png'


def read_img(path):
    img=Image.open(path)
    return transforms.ToTensor()(img)  # convert to a tensor and normalize values to [0, 1]


haze_img,gt_img=read_img(haze_path),read_img(gt_path)
# print(type(haze_img))
# Note: the second unsqueeze below is applied to haze_img again, so gt_img
# ends up identical to haze_img -- which is why every SSIM value below is 1
haze_img,gt_img=torch.unsqueeze(haze_img,0),torch.unsqueeze(haze_img,0)

1) Calculate the SSIM value

# 1) Compute the SSIM value
# Method 1: the plain function
ssim_value=ssim(haze_img,gt_img)
# Method 2: the loss module
ssim_loss=SSIM()(haze_img,gt_img)
print(ssim_value)
print(ssim_loss)
tensor(1.)
tensor(1.)

2) Constructing the loss
SSIM measures the structural similarity of the inputs: the closer the result is to 1, the more similar they are. So when SSIM is used as a training loss, the optimizer must receive the negative SSIM value (the optimizer minimizes by default, and minimizing -SSIM is equivalent to maximizing SSIM).
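As an aside, a common alternative convention (my addition, not used in this post) is to minimize 1 - SSIM instead of handing the optimizer a raw negative value; the optimum is the same, but the loss stays in a familiar non-negative range. A minimal wrapper around the SSIM module above:

class SSIMLoss(nn.Module):
    def __init__(self, window_size=11, size_average=True):
        super(SSIMLoss, self).__init__()
        self.ssim = SSIM(window_size, size_average)  # SSIM module defined above

    def forward(self, img1, img2):
        return 1 - self.ssim(img1, img2)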

# 2) Construct the loss
haze_img,gt_img=Variable(haze_img,requires_grad=True),Variable(gt_img,requires_grad=False)


ssim_loss=SSIM()
optimizer=torch.optim.Adam([haze_img],lr=0.01)

# Initialize
ssim_value=ssim(haze_img,gt_img).item()
train_steps=0
for epoch in range(10):
    print("epoch:",epoch)
    # Training loop: optimize haze_img until its SSIM with gt_img exceeds 0.95
    while ssim_value<0.95:
        print("train times:", train_steps)
        optimizer.zero_grad()
        loss=-ssim_loss(haze_img,gt_img)  # negative SSIM as the loss
        loss.backward()
        optimizer.step()
        ssim_value=ssim(haze_img,gt_img).item()  # track the positive SSIM for the stop condition
        train_steps+=1
epoch: 0
epoch: 1
epoch: 2
epoch: 3
epoch: 4
epoch: 5
epoch: 6
epoch: 7
epoch: 8
epoch: 9

The output is easy to explain: since gt_img was created from haze_img, their SSIM is already 1, so the inner loop is never executed. The main purpose here is to show a code skeleton for training with an SSIM loss.

References:

1) https://blog.csdn.net/qq_35914625/article/details/113789903
2) https://github.com/Po-Hsun-Su/pytorch-ssim

Origin: blog.csdn.net/qq_43665602/article/details/127041832