LLNet模型实现——训练数据准备之抽取训练样本

1. 背景

LLNet模型通过训练高斯噪声自动编码器,实现图像增强和图像噪声去除

2. 代码实现

% LLNet: Deep Autoencoders for Low-light Image Enhancement
% 生成LLNet所需要的训练样本
% Author: HSW
% Date: 2018-05-05
%

% patchsize: 17 x 17 = (2 * ksize + 1) x (2 * ksize + 1) 
ksize = 8;
data_dir = 'C:\Users\heshiwen\Desktop\LLNet模型\png_data_set';
format   = '.png';

% 获取数据集合的全部文件
filePaths = searchRoot(data_dir, format);
fileCnt = length(filePaths);
% 训练样本的保存路径
patchsPath = 'C:\Users\heshiwen\Desktop\LLNet模型\training\image';
labelsPath = 'C:\Users\heshiwen\Desktop\LLNet模型\training\label'; 
% Gamma变换的取值范围: 均匀分布[gammaMin, gammaMax]
gammaMin = 2;
gammaMax = 5;
% Gaussian噪声的方方差范围: 均匀分布[sigmaMin, sigmaMax]
sigmaMinB = 0;
sigmaMaxB = 1;
% 每张图片抽取样本的数量
patchCnt = 2500;
% 每组配置采样的个数
patchPerConfig = 100;
% 样本集合编号
globalNum = 0;
% 是否显示图像
isDisplay = 1; 

for fileIdx = 1:fileCnt
    filePath = filePaths{fileIdx};  % 抽取样本的路径名称
    sample_img = imread(filePath);
    [m, n] = size(sample_img);
    if isDisplay 
        figure(1); 
        imshow(sample_img, []); 
        title('原图像'); 
    end 
    patchMap = zeros([m,n]); 
    % 每张图像采集patchCnt个样本
    imageNum = 0;
    while imageNum < patchCnt
        gammaVal = gammaMin + (gammaMax - gammaMin) * rand(1, 1);
        sigmaB = sigmaMinB + (sigmaMaxB - sigmaMinB) * rand(1, 1);
        sigmaVal = sqrt(sigmaB * (25.0 / 255)^2);
        % 进行Gamma变换
        sample_img_gamma = imadjust(sample_img, [0, 1], [0, 1], gammaVal);
        % 加入噪声
        sample_img_noise = imnoise(sample_img_gamma, 'gaussian', 0, sigmaVal);
        if isDisplay
            figure(2); 
            imshow(sample_img_noise, []); 
            title('Gamma+Noise'); 
        end 
        % 每组配置采样patchPerConfig
        idx = 0;
        while idx < patchPerConfig
            % 图像块中心点
            posX = randi([ksize + 1,  m - ksize],[1, 1]);
            posY = randi([ksize + 1, n - ksize], [1, 1]);
            % 抽取图像块
            patch = sample_img_noise(posX + [-ksize:ksize], posY + [-ksize:ksize]); 
            label = sample_img(posX + [-ksize:ksize], posY + [-ksize:ksize]); 
            patchMap(posX + [-ksize:ksize], posY + [-ksize:ksize]) = patchMap(posX + [-ksize:ksize], posY + [-ksize:ksize]) + 1; 
            saveName = sprintf('%07d.png', globalNum); 
            savePatchPath = fullfile(patchsPath, saveName); 
            saveLabelPath = fullfile(labelsPath, saveName); 
            imwrite(patch, savePatchPath, 'png'); 
            imwrite(label, saveLabelPath, 'png'); 
            imageNum = imageNum + 1; 
            globalNum = globalNum + 1; 
            idx       = idx + 1; 
        end
    end %while
    if isDisplay
        figure(3); 
        imshow(patchMap, []); 
        title('PositionMap'); 
    end 
end

3. 代码效果


猜你喜欢

转载自blog.csdn.net/hit1524468/article/details/80208743
今日推荐