Practical Deep Raw Image Denoising on Mobile Devices(huber linear regression,noise level est)

Practical Deep Raw Image Denoising on Mobile Devices(huber linear regression)

code:https://github.com/MegEngine/PMRID
其他相关博客:https://blog.csdn.net/zjy_snow/article/details/124385456

1.DNG 文件处理 pipeline

关于DNG文件的处理,主要查看 SIDD 的仓库 simple-camera-pipeline.

打印bayer pattern格式
pattern = "".join([chr(short_raw.color_desc[i]) for i in short_raw.raw_pattern.flatten()])

或者
cfa_pattern_str = "".join(["RGB"[i] for i in cfa_pattern])

下面demo_single.py输入dng raw后,可以得到各个阶段的输出结果。

# (demo_single.py)
import glob
import os
import cv2
import numpy as np

from python.pipeline import run_pipeline_v2
from python.pipeline_utils import get_visible_raw_image, get_metadata

# Pipeline configuration: render the DNG from the 'raw' stage through the
# tone-mapped stage and save the result as an 8-bit PNG.
params = {
    'input_stage': 'raw',  # options: 'raw', 'normal', 'white_balance', 'demosaic', 'xyz', 'srgb', 'gamma', 'tone'
    'output_stage': 'tone',  # options: 'normal', 'white_balance', 'demosaic', 'xyz', 'srgb', 'gamma', 'tone'
    'save_as': 'png',  # options: 'jpg', 'png', 'tif', etc.
    'demosaic_type': 'EA',
    'save_dtype': np.uint8
}

image_path = r'D:\dataset\pratical_raw\reno10x_noise\gray_scale_chart\RAW_2020_02_20_13_11_10_787\DNG_2020_02_20_13_11_10_787_1.dng'


# raw image data
raw_image = get_visible_raw_image(image_path)
print('raw data:', raw_image.shape, raw_image.dtype, raw_image.max(), raw_image.min())
# metadata
metadata = get_metadata(image_path)
print('meta info : ', metadata)
# modify WB here
#metadata['as_shot_neutral'] = [1., 1., 1.]

# render the image up to the configured output stage
output_image = run_pipeline_v2(image_path, params)

# save next to the input, suffixed with the output stage name
output_image_path = image_path.replace('.dng', '_{}.'.format(params['output_stage']) + params['save_as'])
# BUG FIX: full scale for uint16 is 2**16 - 1 (65535). Multiplying by 2**16
# would map a value of 1.0 to 65536, which wraps to 0 after the cast.
max_val = 2 ** 16 - 1 if params['save_dtype'] == np.uint16 else 255
# pipeline output is RGB in [0, 1]; flip to BGR for cv2.imwrite
output_image = (output_image[..., ::-1] * max_val).astype(params['save_dtype'])
if params['save_as'] == 'jpg':
    cv2.imwrite(output_image_path, output_image, [cv2.IMWRITE_JPEG_QUALITY, 100])
else:
    cv2.imwrite(output_image_path, output_image)

2.noise estimation

2.1noise model

原论文3.1中给出了简洁明了又清晰的噪声模型:
主要包括高斯噪声和泊松噪声。在这里插入图片描述

在这里插入图片描述

那么通过raw图的均值和方差可以得到参数k和sigma

2.2dataset

code and dataset : https://github.com/MegEngine/PMRID

  1. reno dataset
    reno dataset 包含 [100,200,400,800,1200,1600,2400,3200,4000,4800,5600,6400]iso下的,每个iso下拍摄连续 64个raw图(注意raw图是没有经过blc的)

图像是一个灰卡,光线不均匀,亮度不一
在这里插入图片描述

采集图像,iso变化,iso最大的时候曝光时间 设置为 固定的曝光时间。然后调节灯光亮度。

2.3noise estimation

以reno10x raw图为例,复现文章fig6
在这里插入图片描述

1)得到 6400 iso下的 k-sigma2图像:
在这里插入图片描述

import glob
import os
import numpy as np
from matplotlib import pyplot as plt
from scipy import optimize

from python.pipeline_utils import get_visible_raw_image, get_metadata


def get_rggb_mean_var(raw):
    """Compute per-channel mean and variance of a Bayer raw frame.

    The 2x2 mosaic is split into its four phase sub-planes, and the mean
    and variance are taken over each plane independently.

    :param raw: 2-D Bayer raw image (H x W), H and W even
    :return: (1, 8) array -- the four channel means followed by the four
        channel variances
    """
    planes = np.dstack((raw[0::2, 0::2], raw[0::2, 1::2],
                        raw[1::2, 0::2], raw[1::2, 1::2]))
    means = np.mean(planes, axis=(0, 1)).reshape(1, -1)
    variances = np.var(planes, axis=(0, 1)).reshape(1, -1)
    return np.hstack((means, variances))

if __name__ == "__main__":

    dir = r'D:\dataset\pratical_raw\reno10x_noise\gray_scale_chart\RAW_2020_02_20_13_11_10_787'

    files = glob.glob(os.path.join(dir, '*.dng'))

    raws = []
    noise_level = []
    rggb_mean_var = []
    for file in files:
        image_path = file
        # Subtract the black level (64) -- these raws are not black-level
        # corrected. NOTE(review): black level of 64 assumed constant for
        # this sensor; confirm against the DNG metadata.
        raw_image = get_visible_raw_image(image_path) - 64
        # metadata
        metadata = get_metadata(image_path)
        raws.append(raw_image)

        # first element of each (scale, offset) pair of the DNG NoiseProfile tag
        noise = [tem[0] for tem in metadata['noise_profile']]
        noise_level.append(noise)
        stats = get_rggb_mean_var(raw_image).reshape(-1)
        print(stats)
        rggb_mean_var.append(stats)

    # average the DNG-reported noise profiles over the burst frames
    noise_level = np.dstack(noise_level)
    noise_level = np.mean(noise_level, axis=-1)
    print('noise_level:', noise_level)

    rggb_mean_var = np.dstack(rggb_mean_var)
    rggb_mean_var = np.mean(rggb_mean_var.reshape(8, -1), axis=-1)
    print('rggb_mean_var:', rggb_mean_var)

    # Temporal statistics: stack the burst along the last axis, take the
    # per-pixel mean frame, then for every integer intensity level collect
    # all samples whose mean equals that level and take their variance.
    raws = np.dstack(raws)
    print(raws.dtype, raws.shape)
    raw_mean = np.mean(raws, axis=-1).astype(np.uint16)

    mm = np.arange(metadata['white_level'][0]+1)
    uu = []
    for i in range(metadata['white_level'][0]+1):
        data = raws[raw_mean == i]
        if data is None or len(data) == 0:
            uu.append(-1)  # no pixel at this level; filtered out below
        else:
            uu.append((data.var()))  # filtering outliers here would make the fit more accurate
    uu = np.array(uu)
    print(mm, uu, mm.dtype, mm.shape, uu.dtype, uu.shape)

    # keep only populated intensity levels in the linear (unclipped) range
    mask = np.logical_and(uu >= 0, mm < 350)
    mm = mm[mask]
    uu = uu[mask]
    m_u = np.hstack((mm.reshape(-1, 1), uu.reshape(-1, 1))).astype(np.float32)

    # BUG FIX: after the mask above, mm/uu are already filtered. The original
    # re-filtered with a mix of `uu > 0` and `uu >= 0`; in particular r_y2 was
    # evaluated on mm[uu >= 0] but plotted against mm[uu > 0], which raises a
    # length-mismatch error whenever some variance is exactly 0. Use the
    # masked arrays consistently everywhere.
    # method1: ordinary least-squares line fit, var = k * mean + sigma^2
    z1 = np.polyfit(mm, uu, 1)  # degree-1 fit; coefficients high to low
    p1 = np.poly1d(z1)
    r_y = p1(mm)
    print('z1:', z1)

    # method2: robust Huber-loss line fit, less sensitive to outliers
    def huber_loss(theta, x, y, delta=0.8):
        # quadratic for |residual| < delta, linear beyond
        diff = abs(y - (theta[1] + theta[0] * x))
        return ((diff < delta) * diff ** 2 / 2 + (diff >= delta) * delta * (diff - delta / 2)).sum()
    z2 = optimize.fmin(huber_loss, x0=(0, 0), args=(mm, uu), disp=False)
    print('z2', z2)
    p2 = np.poly1d(z2)
    r_y2 = p2(mm)

    plt.figure()
    plt.plot(mm[::3], uu[::3], 'r+')
    plt.plot(mm, r_y, 'k-')
    plt.plot(mm, r_y2, 'g-')
    plt.show()
    save_dir = r'D:\dataset\pratical_raw\reno10x_noise\gray_scale_chart'
    np.savetxt(os.path.join(save_dir, dir[-27:]+'_k1_sigma2.txt'), np.round(z2, 5), fmt='%.5f')
    np.savetxt(os.path.join(save_dir, dir[-27:] + '_m_var.txt'), np.round(m_u, 2), fmt='%.2f')
    np.savetxt(os.path.join(save_dir, dir[-27:] + '_noise.txt'), np.round(noise_level.reshape(-1), 9), fmt='%.9f')
    np.savetxt(os.path.join(save_dir, dir[-27:] + '_rggb_mean_var.txt'), np.round(rggb_mean_var.reshape(-1), 2), fmt='%.2f')
    print(dir)

2)各个iso下的 k, sigma2:
在这里插入图片描述

基本和文章fig6一致。

import glob
import os
import numpy as np
from matplotlib import pyplot as plt
from scipy import optimize

if __name__ == "__main__":
    # ISO settings of the 12 capture sessions, ascending
    iso = [100,200,400,800,1200,1600,2400,3200,4000,4800,5600,6400]

    dir = r'D:\dataset\pratical_raw\reno10x_noise\gray_scale_chart'

    # BUG FIX: glob.glob returns files in arbitrary, OS-dependent order, but
    # the values are plotted against the ascending `iso` list -- sort the file
    # lists so the pairing is deterministic.
    # NOTE(review): this assumes the timestamp-named result files sort in
    # capture (= ascending ISO) order; verify against the capture session.
    k1_sigma2s = []
    files = sorted(glob.glob(os.path.join(dir, '*k1_sigma2.txt')))
    for file in files:
        k1_sigma2 = np.loadtxt(file)
        print(k1_sigma2)
        k1_sigma2s.append(k1_sigma2)
    k1_sigma2s = np.array(k1_sigma2s)
    k = k1_sigma2s[..., 0]        # Poisson (shot-noise) slope per ISO
    sigma2 = k1_sigma2s[..., 1]   # Gaussian (read-noise) intercept per ISO
    # print('norm k:', k / k[0])

    # k and sigma^2 vs ISO (reproduces fig.6 of the paper)
    plt.figure()
    plt.subplot(121)
    plt.plot(iso, k, 'r+')
    plt.subplot(122)
    plt.plot(iso, sigma2, 'g+')
    plt.show()



    # Channel means change little across ISO; variances grow with ISO.
    rggb_mean_vars = []
    files = sorted(glob.glob(os.path.join(dir, '*rggb_mean_var.txt')))
    for file in files:
        rggb_mean_var = np.loadtxt(file)
        print(rggb_mean_var)
        rggb_mean_vars.append(rggb_mean_var)
    rggb_mean_vars = np.array(rggb_mean_vars)
    print('rggb ratio:\n', rggb_mean_vars / rggb_mean_vars[0])

    # Print the noise levels recorded in the DNG metadata.
    print('\n\n noise level in meta:')
    noises = []
    files = sorted(glob.glob(os.path.join(dir, '*noise.txt')))
    for file in files:
        noise = np.loadtxt(file)
        print(noise)
        noises.append(noise[:2])
    noises = np.array(noises).reshape(-1, 2)

    # Rescale the normalized DNG noise profile to raw data numbers.
    # NOTE(review): 1023 - 64 = white level minus black level; 959**2 scales
    # the variance term -- confirm these against the sensor's levels.
    plt.figure()
    plt.subplot(121)
    plt.plot(iso, noises[:,0]*(1023-64), 'r+')
    plt.subplot(122)
    plt.plot(iso, noises[:,1]*(959*959), 'g+')
    plt.show()

3.论文中的测试集

拍了五个场景,每个场景两种光源调节bright和dark.
每种光源条件又用了5种iso+expo time.
在这里插入图片描述

import glob
import os
import numpy as np
from matplotlib import pyplot as plt



'''
process PMRID dataset
'''

if __name__ == "__main__":
    # Root of one PMRID test scene / lighting condition; each RAW* folder
    # holds the .raw files for one ISO + exposure-time setting.
    dir_ori = r'D:\dataset\pratical_raw\PMRID\Scene4\dark'
    dirs = glob.glob(os.path.join(dir_ori, 'RAW*'))
    print(dirs)
    for dir in dirs:
        files = glob.glob(os.path.join(dir, "*.raw"))

        for file in files:
            print(file)

            # headerless 16-bit dump at a fixed 3000 x 4000 resolution
            data = np.fromfile(file, dtype=np.uint16)
            data = data.reshape(3000, 4000)

            # Quick RGB preview from three Bayer sub-planes: R at (1,1),
            # G at (1,0), B at (0,0) -- i.e. assumes a BGGR pattern.
            # NOTE(review): confirm against the sensor's actual CFA layout.
            rgb = np.dstack((data[1::2,1::2], data[1::2,0::2], data[0::2,0::2]))
            plt.figure()
            # normalize by the frame maximum just for display
            plt.imshow(rgb / rgb.max())
            plt.show()

4.论文总结

  1. 提出根据raw图泊松高斯噪声估计模型,即各种iso下的噪声参数
  2. 根据噪声参数可以生成泛化能力强的噪声图片
  3. 提出k-sigma转换处理 input 和 output, 可以使网络学习 iso-independent space, 因此不需要扩大网络模型,就可以训练一个能处理各种iso噪声的轻量化模型,且效果很好。

在这里插入图片描述

5.net

#!/usr/bin/env python3
import torch
import torch.nn as nn
from collections import OrderedDict

import numpy as np


def Conv2D(
        in_channels: int, out_channels: int,
        kernel_size: int, stride: int, padding: int,
        is_seperable: bool = False, has_relu: bool = False,
):
    """Build a conv unit as an nn.Sequential.

    Either a plain Conv2d, or a depthwise-separable pair (per-channel
    depthwise conv followed by a 1x1 pointwise conv), optionally capped
    with a ReLU.
    """
    if is_seperable:
        layers = [
            # depthwise conv carries no bias; the pointwise conv does
            ('depthwise', nn.Conv2d(
                in_channels, in_channels, kernel_size, stride, padding,
                groups=in_channels, bias=False,
            )),
            ('pointwise', nn.Conv2d(
                in_channels, out_channels,
                kernel_size=1, stride=1, padding=0, bias=True,
            )),
        ]
    else:
        layers = [
            ('conv', nn.Conv2d(
                in_channels, out_channels, kernel_size, stride, padding,
                bias=True,
            )),
        ]
    if has_relu:
        layers.append(('relu', nn.ReLU()))

    return nn.Sequential(OrderedDict(layers))


class EncoderBlock(nn.Module):
    """Residual encoder block.

    Two 5x5 separable convs (the first may downsample via `stride`) plus
    a projection shortcut; the shortcut is an identity when both the
    spatial size and the channel count are unchanged. ReLU on the sum.
    """

    def __init__(self, in_channels: int, mid_channels: int, out_channels: int, stride: int = 1):
        super().__init__()

        self.conv1 = Conv2D(in_channels, mid_channels, kernel_size=5, stride=stride, padding=2, is_seperable=True, has_relu=True)
        self.conv2 = Conv2D(mid_channels, out_channels, kernel_size=5, stride=1, padding=2, is_seperable=True, has_relu=False)

        if stride == 1 and in_channels == out_channels:
            self.proj = nn.Identity()
        else:
            # 3x3 separable conv to match the shortcut's shape to the output
            self.proj = Conv2D(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, is_seperable=True, has_relu=False)
        self.relu = nn.ReLU()

    def forward(self, x):
        shortcut = self.proj(x)
        out = self.conv2(self.conv1(x))
        return self.relu(out + shortcut)


def EncoderStage(in_channels: int, out_channels: int, num_blocks: int):
    """One encoder stage as an nn.Sequential.

    A stride-2 downsampling EncoderBlock followed by (num_blocks - 1)
    stride-1 blocks, all at `out_channels` with a 4x channel bottleneck.
    """
    mid = out_channels // 4
    stage = [
        EncoderBlock(
            in_channels=in_channels,
            mid_channels=mid,
            out_channels=out_channels,
            stride=2,
        )
    ]
    stage += [
        EncoderBlock(
            in_channels=out_channels,
            mid_channels=mid,
            out_channels=out_channels,
            stride=1,
        )
        for _ in range(num_blocks - 1)
    ]
    return nn.Sequential(*stage)


class DecoderBlock(nn.Module):
    """Residual refinement block: two separable convs with an identity
    shortcut. Channel count must be preserved for the residual add."""

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3):
        super().__init__()

        pad = kernel_size // 2  # "same" padding for odd kernel sizes
        self.conv0 = Conv2D(
            in_channels, out_channels, kernel_size=kernel_size, padding=pad,
            stride=1, is_seperable=True, has_relu=True,
        )
        self.conv1 = Conv2D(
            out_channels, out_channels, kernel_size=kernel_size, padding=pad,
            stride=1, is_seperable=True, has_relu=False,
        )

    def forward(self, x):
        # residual: input plus the two-conv refinement
        return x + self.conv1(self.conv0(x))


class DecoderStage(nn.Module):
    """Decoder stage: refine the incoming features, upsample them 2x with
    a transposed conv, and fuse the projected encoder skip by addition."""

    def __init__(self, in_channels: int, skip_in_channels: int, out_channels: int):
        super().__init__()

        self.decode_conv = DecoderBlock(in_channels, in_channels, kernel_size=3)
        self.upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0)
        self.proj_conv = Conv2D(skip_in_channels, out_channels, kernel_size=3, stride=1, padding=1, is_seperable=True, has_relu=True)
        # M.init.msra_normal_(self.upsample.weight, mode='fan_in', nonlinearity='linear')

    def forward(self, inputs):
        # inputs: (features from the previous decoder stage, encoder skip)
        feat, skip = inputs
        upsampled = self.upsample(self.decode_conv(feat))
        return upsampled + self.proj_conv(skip)


class Network(nn.Module):
    """U-Net style denoiser for 4-channel packed-Bayer input.

    Four encoder stages halve the resolution (16->64->128->256->512
    channels), a separable conv squeezes the bottleneck to 64 channels,
    and four decoder stages upsample back while fusing the matching
    encoder features. The head predicts a residual that is added to the
    input, so the network effectively learns the noise.
    """

    def __init__(self):
        super().__init__()

        # attribute names kept identical to preserve state-dict layout
        self.conv0 = Conv2D(in_channels=4, out_channels=16, kernel_size=3, padding=1, stride=1, is_seperable=False, has_relu=True)
        self.enc1 = EncoderStage(in_channels=16, out_channels=64, num_blocks=2)
        self.enc2 = EncoderStage(in_channels=64, out_channels=128, num_blocks=2)
        self.enc3 = EncoderStage(in_channels=128, out_channels=256, num_blocks=4)
        self.enc4 = EncoderStage(in_channels=256, out_channels=512, num_blocks=4)

        self.encdec = Conv2D(in_channels=512, out_channels=64, kernel_size=3, padding=1, stride=1, is_seperable=True, has_relu=True)
        self.dec1 = DecoderStage(in_channels=64, skip_in_channels=256, out_channels=64)
        self.dec2 = DecoderStage(in_channels=64, skip_in_channels=128, out_channels=32)
        self.dec3 = DecoderStage(in_channels=32, skip_in_channels=64, out_channels=32)
        self.dec4 = DecoderStage(in_channels=32, skip_in_channels=16, out_channels=16)

        self.out0 = DecoderBlock(in_channels=16, out_channels=16, kernel_size=3)
        self.out1 = Conv2D(in_channels=16, out_channels=4, kernel_size=3, stride=1, padding=1, is_seperable=False, has_relu=False)

    def forward(self, inp):
        # encoder path; features are kept for the skip connections
        feat0 = self.conv0(inp)
        feat1 = self.enc1(feat0)
        feat2 = self.enc2(feat1)
        feat3 = self.enc3(feat2)
        feat4 = self.enc4(feat3)

        # bottleneck channel reduction
        bottleneck = self.encdec(feat4)

        # decoder path, fusing the matching encoder features
        up = self.dec1((bottleneck, feat3))
        up = self.dec2((up, feat2))
        up = self.dec3((up, feat1))
        up = self.dec4((up, feat0))

        residual = self.out1(self.out0(up))

        # residual prediction: add the learned correction to the input
        return inp + residual


if __name__ == "__main__":
    # smoke test: run one random 64x64 4-channel patch through the network
    net = Network()
    # img = mge.tensor(np.random.randn(1, 4, 64, 64).astype(np.float32))
    img = torch.randn(1, 4, 64, 64, device=torch.device('cpu'), dtype=torch.float32)
    out = net(img)

    # drop into an interactive shell to inspect `net`/`out` (needs IPython)
    import IPython; IPython.embed()

# vim: ts=4 sw=4 sts=4 expandtab

猜你喜欢

转载自blog.csdn.net/tywwwww/article/details/131284403
EST