Low-light enhancement paper reading: "Toward Fast, Flexible, and Robust Low-Light Image Enhancement"


Foreword

This post introduces a recent paper on low-light enhancement, Self-Calibrated Illumination (SCI). The method achieves very good results and is worth studying and thinking about.

Paper Title : Toward Fast, Flexible, and Robust Low-Light Image Enhancement

Paper information : A CVPR 2022 Oral paper from Dalian University of Technology (arXiv, April 2022).

Paper address : https://arxiv.org/abs/2204.10137

The main contributions of the paper are summarized as follows:
1. We propose a self-calibrated, weight-sharing illumination learning module that makes the results of all stages converge, improves exposure stability, and greatly reduces the computational cost. To our knowledge, this is the first work that exploits the learning process to accelerate low-light image enhancement algorithms.
2. We define unsupervised training losses that constrain the output of each stage under the effect of the self-calibrated module, endowing the model with adaptability to different scenes. Property analysis shows that SCI possesses operation-insensitive adaptability and model-irrelevant generality, which are absent from the existing literature.
3. We conduct extensive experiments to demonstrate that our method outperforms other state-of-the-art methods. Further applications to dark face detection and nighttime semantic segmentation reveal the practical value of this method. In brief, SCI redefines the peak point of visual quality, computational efficiency, and downstream task performance in network-based low-light image enhancement.

The figure below compares SCI with other low-light image enhancement methods; overall, SCI is a "hexagonal warrior", i.e., an all-rounder that scores well on every axis of the radar chart.

1. Basic principles

The article innovatively proposes the Self-Calibrated Illumination (SCI) learning framework. By introducing the self-calibrated module, the heavy inference cost of the cascade mechanism is greatly reduced and inference is accelerated.
The basic principle of the network is still the classic Retinex theory, so let's review that theory first.

Retinex theory

So a low-light image y equals its clear image z (the reflectance component) multiplied element-wise by the illumination x. That is:

y = z ⊗ x


Core idea : obtain the reflectance image, which carries the essential information of the scene. By separating out the illumination component, the influence of lighting on the image can be reduced, the detail information can be enhanced, and a representation of the image's essential content can be obtained.

For the observed image S(x,y) (the y above), if we want the enhanced image R(x,y) (the z above), the key is how to estimate the illumination I(x,y) (the x above).

Single Scale Retinex Algorithm (SSR)

For the formula S = I × R, taking the logarithm of both sides gives:

log[R(x,y)] = log[S(x,y)] − log[I(x,y)]

The proponents of Retinex theory pointed out that I(x,y) can be approximated by Gaussian-blurring the input image, i.e., I(x,y) ≈ G(x,y) * S(x,y), where G is a Gaussian kernel and * denotes convolution.
As for the specific process of Gaussian blur, you can easily look it up.
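To make SSR concrete, here is a minimal sketch of my own (not from the paper), assuming OpenCV and NumPy are available; sigma is a hand-picked surround scale:

import cv2
import numpy as np

def single_scale_retinex(img, sigma=80):
    # log(R) = log(S) - log(I), with I approximated by Gaussian-blurring S
    s = img.astype(np.float32) + 1.0             # +1 avoids log(0)
    illum = cv2.GaussianBlur(s, (0, 0), sigma)   # approximate illumination I
    r = np.log(s) - np.log(illum)                # reflectance in the log domain
    r = (r - r.min()) / (r.max() - r.min() + 1e-8)
    return (r * 255).astype(np.uint8)            # stretch back for display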

As for whether I(x,y) can really be recovered accurately by Gaussian blur, I personally think there is no rigorous mathematical proof at present; it is just an approximation.
Another approach is to exploit the strengths of CNNs: design a network and learn this critical illumination I(x,y) from training data. Self-Calibrated Illumination is exactly such a network for learning the illumination.

2. The content of the thesis

1. Network structure

[Figure: overall architecture of SCI, weight-sharing illumination estimation plus the self-calibrated module]
As shown in the figure, the whole structure is divided into two parts: the Self-Calibrated Module and the Illumination Estimation module, where the self-calibrated module is an auxiliary module that reduces the computational burden of the cascade scheme.

Illumination Estimation

Instead of directly learning a mapping from the image to the illumination, the author proposes a new way of learning the illumination.
Let's look at the illumination estimation module first:
u^t = Hθ(x^t),  x^{t+1} = x^t + u^t,  with x^0 = y (the low-light input)

u^t: the residual at stage t. Computing a residual greatly reduces the computational cost and keeps the process stable, which is especially helpful for exposure control. (This resembles the idea of ResNet: each stage of the cascaded network learns a small increment of illumination, and the increments compose the full illumination.)
x^t: the illumination at stage t.
Hθ: the illumination estimation network. Hθ does not depend on the stage index; that is, the network shares its structure and parameters across all stages.
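To make the weight-sharing cascade concrete, here is a small sketch of my own (h_theta stands for any illumination estimation network; this is an illustration, not the official code):

import torch
import torch.nn as nn

def illumination_cascade(h_theta: nn.Module, y: torch.Tensor, stages: int = 3):
    # u^t = H_theta(x^t);  x^{t+1} = x^t + u^t,  with x^0 = y
    x, xs = y, []
    for _ in range(stages):
        u = h_theta(x)   # residual learned at stage t
        x = x + u        # accumulate into the illumination estimate
        xs.append(x)
    return xs            # one shared-weight module is reused at every stage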

Self-Calibrated Module: its role is to make the results of each stage converge to the same state.

z^t = y ⊘ x^t,  s^t = Kϑ(z^t),  v^t = y + s^t  (⊘ denotes element-wise division)
y: the low-light input image
z: the target (enhanced) image
s: the self-calibration map
Kϑ: a parameterized operator whose parameters ϑ are learnable
v^t: the calibrated input for the next stage

The low-light input (i.e., the input of the first stage) is brought in to indirectly explore the convergence behavior between stages: a self-calibration map s is introduced to represent the difference between the input of each stage and the input of the first stage. The self-calibrated module thus makes the outputs of the different stages converge to the same state during training.

The basic unit of the illumination optimization process is then reformulated as:
u^t = Hθ(v^t),  x^{t+1} = v^t + u^t,  with v^0 = y

The introduction of the self-calibrated module makes the results of different stages converge quickly to the same state, i.e., the results of the three stages coincide; without the self-calibrated module, this convergence is not observed.

[Figure: stage-wise convergence behavior with vs. without the self-calibrated module]

I redrew a diagram based on the content above for easier understanding.

[Figure: redrawn diagram of the stage-wise SCI pipeline]
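In the same spirit, one training-time unfolding step can be sketched in code (my own illustration; h_theta and k_vartheta are stand-ins for the two networks):

import torch

def sci_unfolding(h_theta, k_vartheta, y, stages=3):
    # z^t = y / x^t;  v^t = y + K(z^t);  u^t = H(v^t);  x^{t+1} = v^t + u^t
    v, xs = y, []
    for _ in range(stages):
        x = torch.clamp(v + h_theta(v), 1e-4, 1.0)  # illumination at this stage
        z = torch.clamp(y / x, 0.0, 1.0)            # intermediate enhanced image
        v = y + k_vartheta(z)                       # calibrated input for next stage
        xs.append(x)
    return xs  # at inference time only the first stage's illumination is needed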

2. Loss function

Considering the inaccuracy of existing paired data, the paper adopts unsupervised learning to broaden the network's applicability. The unsupervised loss is defined as:

L_total = α · L_f + β · L_s

It consists of two parts: the fidelity loss L_f and the smoothing loss L_s; α and β are two balancing weights.

Fidelity loss

L_f ensures pixel-level consistency between the estimated illumination and the (calibrated) input of each stage, where T is the total number of stages:

L_f = Σ_{t=1}^{T} ‖ x^t − v^{t−1} ‖²,  with v^0 = y

This is easy to understand: x^t is the illumination at stage t, and the term inside the norm is the calibrated input v^{t−1} produced by the self-calibrated module. Since the module's job is to make the results of all stages agree, these two quantities should stay close at every stage.
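In PyTorch this term can be sketched as follows (my own simplification, mirroring _loss() in the official code below; F.mse_loss averages rather than sums, so it matches the formula up to a constant factor):

import torch.nn.functional as F

def fidelity_loss(x_list, v_list):
    # L_f = sum over stages of || x^t - v^{t-1} ||^2
    return sum(F.mse_loss(x, v) for x, v in zip(x_list, v_list))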

Smoothing loss

L_s = Σ_{i=1}^{N} Σ_{j∈N(i)} w_{i,j} |x_i^t − x_j^t|
N: the total number of pixels
i: the i-th pixel
N(i): the neighbors of pixel i within its 5 × 5 window
w_{i,j}: the weight
w_{i,j} = exp( − Σ_c (y_{i,c} − y_{j,c})² / (2σ²) )  (a Gaussian kernel computed on the input image y over channels c)
My understanding is that the illumination estimated at each stage should have a smooth overall light-dark distribution rather than being locally too bright or too dark, so each pixel's value should stay close to the values of its surrounding pixels.
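A simplified sketch of this idea (my own; it uses only horizontal/vertical neighbors instead of the paper's 5 × 5 window, and sigma is an assumed bandwidth):

import torch

def smooth_loss_simplified(illu, img, sigma=0.1):
    # penalize illumination differences between neighboring pixels, down-weighting
    # the penalty across strong edges of the input image (Gaussian kernel w_ij);
    # boundary wrap-around from torch.roll is ignored for brevity
    def weighted_diff(dim):
        d_illu = (illu - torch.roll(illu, 1, dims=dim)).abs()
        d_img = ((img - torch.roll(img, 1, dims=dim)) ** 2).mean(1, keepdim=True)
        w = torch.exp(-d_img / (2 * sigma ** 2))
        return (w * d_illu).mean()
    return weighted_diff(2) + weighted_diff(3)   # vertical + horizontal neighbors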

3. Discussion

Operation-Insensitive Adaptability (i.e., stable performance under different simple operation settings)

[Figure: enhancement results under different operation settings]
The general idea is that SCI brightens low-light observations under different settings and produces very similar enhancement results.
The reason is that SCI not only exploits the consensus of illumination (i.e., residual learning) but also integrates the physical principle (i.e., the pixel-wise division operation).

Model-Irrelevant Generality (i.e., it can be applied to existing illumination-based works to improve their performance)

[Figure: applying SCI to existing illumination-based methods]
Our SCI is in fact a generalized learning paradigm: if it is not restricted to the task-specific self-calibrated module, it can ideally be applied directly to existing works.

Finally, the author ran a comparative experiment:
learning the illumination directly → overexposed images
learning the residual between the illumination and the input → suppresses overexposure, but the overall image quality is still not high, especially in the details
In contrast, the enhancement results of our method not only suppress overexposure but also enrich the image structure.
[Figure: visual comparison of the three illumination learning strategies]

3. Model code (official code)

import torch
import torch.nn as nn
from loss import LossFunction

class EnhanceNetwork(nn.Module):
    def __init__(self, layers, channels):
        super(EnhanceNetwork, self).__init__()

        kernel_size = 3
        dilation = 1
        padding = int((kernel_size - 1) / 2) * dilation

        self.in_conv = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=channels, kernel_size=kernel_size, stride=1, padding=padding),
            nn.ReLU()
        )

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=kernel_size, stride=1, padding=padding),
            nn.BatchNorm2d(channels),
            nn.ReLU()
        )

        self.blocks = nn.ModuleList()
        for i in range(layers):
            # the same conv module is appended each time, so all blocks share weights
            self.blocks.append(self.conv)

        self.out_conv = nn.Sequential(
            nn.Conv2d(in_channels=channels, out_channels=3, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )

    def forward(self, input):
        fea = self.in_conv(input)
        for conv in self.blocks:
            fea = fea + conv(fea)                # residual connection inside each block
        fea = self.out_conv(fea)

        illu = fea + input                       # illumination = input + learned residual
        illu = torch.clamp(illu, 0.0001, 1)      # keep illumination in (0, 1] so y/x is safe

        return illu


class CalibrateNetwork(nn.Module):
    def __init__(self, layers, channels):
        super(CalibrateNetwork, self).__init__()
        kernel_size = 3
        dilation = 1
        padding = int((kernel_size - 1) / 2) * dilation
        self.layers = layers

        self.in_conv = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=channels, kernel_size=kernel_size, stride=1, padding=padding),
            nn.BatchNorm2d(channels),
            nn.ReLU()
        )

        self.convs = nn.Sequential(
            nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=kernel_size, stride=1, padding=padding),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=kernel_size, stride=1, padding=padding),
            nn.BatchNorm2d(channels),
            nn.ReLU()
        )
        self.blocks = nn.ModuleList()
        for i in range(layers):
            # as above, one shared convs module is reused across all blocks
            self.blocks.append(self.convs)

        self.out_conv = nn.Sequential(
            nn.Conv2d(in_channels=channels, out_channels=3, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )

    def forward(self, input):
        fea = self.in_conv(input)
        for conv in self.blocks:
            fea = fea + conv(fea)

        fea = self.out_conv(fea)
        delta = input - fea                      # self-calibration map: input minus its transform

        return delta



class Network(nn.Module):

    def __init__(self, stage=3):
        super(Network, self).__init__()
        self.stage = stage
        self.enhance = EnhanceNetwork(layers=1, channels=3)
        self.calibrate = CalibrateNetwork(layers=3, channels=16)
        self._criterion = LossFunction()

    def weights_init(self, m):
        if isinstance(m, nn.Conv2d):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()

        if isinstance(m, nn.BatchNorm2d):
            m.weight.data.normal_(1., 0.02)

    def forward(self, input):

        ilist, rlist, inlist, attlist = [], [], [], []
        input_op = input                         # v^0 = y (the low-light input)
        for _ in range(self.stage):
            inlist.append(input_op)              # stage input v^{t-1}
            i = self.enhance(input_op)           # illumination x^t
            r = input / i                        # enhanced image z^t = y / x^t
            r = torch.clamp(r, 0, 1)
            att = self.calibrate(r)              # self-calibration map s^t
            input_op = input + att               # calibrated input v^t = y + s^t
            ilist.append(i)
            rlist.append(r)
            attlist.append(torch.abs(att))

        return ilist, rlist, inlist, attlist

    def _loss(self, input):
        i_list, en_list, in_list, _ = self(input)
        loss = 0
        for i in range(self.stage):
            loss += self._criterion(in_list[i], i_list[i])
        return loss



class Finetunemodel(nn.Module):

    def __init__(self, weights):
        super(Finetunemodel, self).__init__()
        self.enhance = EnhanceNetwork(layers=1, channels=3)
        self._criterion = LossFunction()

        base_weights = torch.load(weights)
        pretrained_dict = base_weights
        model_dict = self.state_dict()
        # keep only the entries that exist in this inference-only model
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        model_dict.update(pretrained_dict)
        self.load_state_dict(model_dict)

    def weights_init(self, m):
        if isinstance(m, nn.Conv2d):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()

        if isinstance(m, nn.BatchNorm2d):
            m.weight.data.normal_(1., 0.02)

    def forward(self, input):
        i = self.enhance(input)
        r = input / i
        r = torch.clamp(r, 0, 1)
        return i, r
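For reference, a minimal usage sketch of the classes above (my assumptions: inputs in [0, 1] with NCHW layout, and the official loss.py importable for LossFunction):

model = Network(stage=3)
y = torch.rand(1, 3, 256, 256)               # a fake low-light batch
ilist, rlist, inlist, attlist = model(y)     # per-stage illuminations x^t, enhanced
                                             # images z^t, inputs v^{t-1}, |maps s^t|
print(ilist[-1].shape, rlist[-1].shape)      # both torch.Size([1, 3, 256, 256])
loss = model._loss(y)                        # unsupervised training loss

At test time, Finetunemodel runs only the single weight-shared illumination network, i.e., one forward pass per image.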

Summary

SCI opens up a new perspective: introducing an auxiliary process during training to strengthen the modeling ability of the basic unit.

Much of the above is mixed with my personal understanding; if there are mistakes, criticism and corrections are welcome!
