基于SRGAN的人脸图像超分辨率

引言

SRGAN是第一个将GAN用在图像超分辨率上的模型。在这之前,超分辨率常用的损失是L1、L2这种像素损失,这使得模型倾向于学习到平均的结果,也就是给低分辨率图像增加“模糊的细节”。SRGAN引入GAN来解决这个问题。GAN可以生成“真实”的图像, 那么当“真实的图像”是清晰的图像时,也意味着GAN可以生成清晰的图像。但是,如果只用GAN损失,没有其他约束,并不能生成与低分辨率图像对应的高分辨率图像。所以,将像素损失和对抗损失相结合。此外,SRGAN还使用了感知损失,计算图像在特征空间的损失。

准备

import numpy as np
import matplotlib.pyplot as plt
import os
from PIL import Image
import paddle
import paddle as P
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout, AdaptiveAvgPool2D, MaxPool2D, AvgPool2D

nn.initializer.set_global_initializer(nn.initializer.Normal(mean=0.0,std=0.01), nn.initializer.Constant())

加载数据

使用CelebA数据集,实现人脸图像超分辨率。
为了不OOM,切块大小为44×44(而且CelebA也只能切这么大了),与原文96×96不同。

SCALE = 4
PATH = '/path/to/data/celeba/img_align_celeba/'
DIRS = os.listdir(PATH)
PATCH_SIZE = [44, 44, 3]


def reader_patch(batchsize,scale=SCALE,patchsize=PATCH_SIZE):
    np.random.shuffle(DIRS)
    for filename in DIRS:
        LRs = np.zeros((batchsize,patchsize[2],patchsize[0],patchsize[1])).astype("float32")
        HRs = np.zeros((batchsize,patchsize[2],patchsize[0]*scale,patchsize[1]*scale)).astype("float32")
        image = Image.open(PATH+filename)
        sz = image.size
        sz_row = sz[1]//(patchsize[0]*scale)*patchsize[0]*scale
        diff_row = sz[1] - sz_row
        sz_col = sz[0]//(patchsize[1]*scale)*patchsize[1]*scale
        diff_col = sz[0] - sz_col
        row_min = np.random.randint(diff_row+1)
        col_min = np.random.randint(diff_col+1)
        HR = image.crop((col_min,row_min,col_min+sz_col,row_min+sz_row))
        LR = HR.resize((sz[0]//(patchsize[1]*scale)*patchsize[1],sz[1]//(patchsize[0]*scale)*patchsize[0]), Image.BICUBIC)
        LR = np.array(LR).astype("float32") / 255 * 2 - 1
        HR = np.array(HR).astype("float32") / 255 * 2 - 1
        for batch in range(batchsize):
            rowMin, colMin = np.random.randint(0,LR.shape[0]-patchsize[0]+1), np.random.randint(0,LR.shape[1]-patchsize[1]+1)
            LRs[batch,:,:,:] = LR[rowMin:rowMin+patchsize[0], colMin:colMin+patchsize[1],:].transpose([2,0,1])
            HRs[batch,:,:,:] = HR[scale*rowMin:scale*(rowMin+patchsize[0]), scale*colMin:scale*(colMin+patchsize[1])].transpose([2,0,1])
        yield LRs, HRs


def data_augmentation(LR, HR): #数据增强:随机翻转、旋转
    if np.random.randint(2) == 1:
        LR = LR[:,:,:,::-1]
        HR = HR[:,:,:,::-1]
    n = np.random.randint(4)
    if n == 1:
        LR = LR[:,:,::-1,:].transpose([0,1,3,2])
        HR = HR[:,:,::-1,:].transpose([0,1,3,2])
    if n == 2:
        LR = LR[:,:,::-1,::-1]
        HR = HR[:,:,::-1,::-1]
    if n == 3:
        LR = LR[:,:,:,::-1].transpose([0,1,3,2])
        HR = HR[:,:,:,::-1].transpose([0,1,3,2])
    return LR, HR


data = reader_patch(1)
for i in range(2):
    LR, HR = next(data)
    LR = LR.transpose([2,3,1,0]).reshape(PATCH_SIZE[0],PATCH_SIZE[1],PATCH_SIZE[2])
    LR = Image.fromarray(np.uint8((LR+1)/2*255))
    HR = HR.transpose([2,3,1,0]).reshape(PATCH_SIZE[0]*SCALE,PATCH_SIZE[1]*SCALE,PATCH_SIZE[2])
    HR = Image.fromarray(np.uint8((HR+1)/2*255))
    plt.subplot(1,2,1), plt.imshow(LR),plt.title('LRx'+str(SCALE))
    plt.subplot(1,2,2), plt.imshow(HR),plt.title('HR')
    plt.show()

网络结构

生成器整体结构:

这是一个残差网络,名为SRResNet。首先用一个卷积提取浅层特征,然后经过一个残差层提取深层特征,最后是一个上采样层重建出高分辨率图像。
其中残差层包括16个残差块、一个卷积和跳级连接。
上采样层有两个上采样块和一个卷积。
除了第一个卷积和上采样层中的卷积,每个卷积后面都有BN(其实,BN在SR中没有效果甚至略差,SR输入和输出有相似的空间分布,而BN白化中间的特征的方式完全破坏了原始空间的表征,因此需要部分参数来恢复这种表征,所以同样多的参数,有BN的还要拿出一部分参数做恢复,效果就差了点)。
激活函数都为PReLU,由于我不知道怎么实现PReLU,所以用ReLU代替。。。

class G(nn.Layer): # 生成器SRResNet

    def __init__(self, channel=64, num_rb=16):
        super(G, self).__init__()
        self.conv1 = nn.Conv2D(3, channel, 9, 1, 4)
        # self.prelu = nn.PReLU('all')
        self.prelu = nn.ReLU()
        self.rb_list = []
        for i in range(num_rb):
            self.rb_list += [self.add_sublayer('rb_%d' % i, RB(channel))]
        self.conv2 = nn.Conv2D(channel, channel, 3, 1, 1)
        self.bn = nn.BatchNorm2D(channel)
        self.us1 = US(channel, channel*4)
        self.us2 = US(channel, channel*4)
        self.conv3 = nn.Conv2D(channel, 3, 9, 1, 4)
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.prelu(x)
        y = x
        for rb in self.rb_list:
            y = rb(y)
        y = self.conv2(y)
        y = self.bn(y)
        y = x + y
        y = self.us1(y)
        y = self.us2(y)
        y = self.conv3(y)
        return y

残差块:

这是一个经典的残差块:conv、bn、relu(prelu)、conv、bn加跳过连接。

class RB(nn.Layer): # 残差块

    def __init__(self, channel=64):
        super(RB, self).__init__()
        self.conv1 = nn.Conv2D(channel, channel, 3, 1, 1)
        self.bn1 = nn.BatchNorm2D(channel)
        # self.prelu = nn.PReLU('all')
        self.prelu = nn.ReLU()
        self.conv2 = nn.Conv2D(channel, channel, 3, 1, 1)
        self.bn2 = nn.BatchNorm2D(channel)
    
    def forward(self, x):
        y = self.conv1(x)
        y = self.bn1(y)
        y = self.prelu(y)
        y = self.conv2(y)
        y = self.bn2(y)
        return x + y

上采样块:

包括conv、upscale_factor为2的pixelshuffle和prelu。
网络里用了两个上采样块,所以总的upscale_factor为4。

class US(nn.Layer): # 上采样块
    
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(US, self).__init__()
        self.conv = nn.Conv2D(in_channels, out_channels, kernel_size, stride, padding)
        self.ps = nn.PixelShuffle (2)
        # self.prelu = nn.PReLU('all')
        self.prelu = nn.ReLU()
    
    def forward(self, x):
        x = self.conv(x)
        x = self.ps(x)
        x = self.prelu(x)
        return x

判别器整体结构:

这是一个经典的结构,包括一系列的conv-bn-leakyrelu和两个全连接。
第一个conv后没有bn;除了最后的激活函数为sigmoid,其他都为leakyrelu。
由于有全连接的存在,不同的输入尺寸会有不同的全连接参数数量,这里的参数数量与论文中不同。

class D(nn.Layer): # 判别器

    def __init__(self, channel=64):
        super(D, self).__init__()
        self.layer_list = []
        self.layer_list += [self.add_sublayer('conv', nn.Conv2D(3, channel, 3, 1, 1))]
        self.layer_list += [self.add_sublayer('lrelu1', nn.LeakyReLU())]
        self.layer_list += [self.add_sublayer('cna1', CNA(channel, channel, 3, 2, [1,0,1,0]))]
        self.layer_list += [self.add_sublayer('cna2', CNA(channel, channel*2))]
        self.layer_list += [self.add_sublayer('cna3', CNA(channel*2, channel*2, 3, 2, [1,0,1,0]))]
        self.layer_list += [self.add_sublayer('cna4', CNA(channel*2, channel*4))]
        self.layer_list += [self.add_sublayer('cna5', CNA(channel*4, channel*4, 3, 2, [1,0,1,0]))]
        self.layer_list += [self.add_sublayer('cna6', CNA(channel*4, channel*8))]
        self.layer_list += [self.add_sublayer('cna7', CNA(channel*8, channel*8, 3, 2, [1,0,1,0]))]
        self.layer_list += [self.add_sublayer('flatten', nn.Flatten(start_axis=1, stop_axis=3))]
        self.layer_list += [self.add_sublayer('fc1', nn.Linear(PATCH_SIZE[0]*4//16*PATCH_SIZE[1]*4//16*channel*8, channel*16))]
        self.layer_list += [self.add_sublayer('lrelu2', nn.LeakyReLU())]
        self.layer_list += [self.add_sublayer('fc1', nn.Linear(channel*16, 1))]
        self.layer_list += [self.add_sublayer('sigmoid', nn.Sigmoid())]
    
    def forward(self, x):
        for layer in self.layer_list:
            x = layer(x)
        return x

conv + norm + act:

class CNA(nn.Layer): # conv-norm-act
    
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(CNA, self).__init__()
        self.conv = nn.Conv2D(in_channels, out_channels, kernel_size, stride, padding)
        self.bn = nn.BatchNorm(out_channels)
        self.lrelu = nn.LeakyReLU()
    
    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.lrelu(x)
        return x

预训练网络VGG19。
代码链接:
https://github.com/PaddlePaddle/PaddleClas/blob/dygraph/ppcls/modeling/architectures/vgg.py
参数下载链接:
https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/VGG19_pretrained.pdparams
这里使用conv5_4后激活层的输出。

class ConvBlock(nn.Layer):
    def __init__(self, input_channels, output_channels, groups, name=None):
        super(ConvBlock, self).__init__()

        self.groups = groups
        self._conv_1 = Conv2D(
            in_channels=input_channels,
            out_channels=output_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            weight_attr=ParamAttr(name=name + "1_weights"),
            bias_attr=False)
        if groups == 2 or groups == 3 or groups == 4:
            self._conv_2 = Conv2D(
                in_channels=output_channels,
                out_channels=output_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                weight_attr=ParamAttr(name=name + "2_weights"),
                bias_attr=False)
        if groups == 3 or groups == 4:
            self._conv_3 = Conv2D(
                in_channels=output_channels,
                out_channels=output_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                weight_attr=ParamAttr(name=name + "3_weights"),
                bias_attr=False)
        if groups == 4:
            self._conv_4 = Conv2D(
                in_channels=output_channels,
                out_channels=output_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                weight_attr=ParamAttr(name=name + "4_weights"),
                bias_attr=False)

        self._pool = MaxPool2D(kernel_size=2, stride=2, padding=0)

    def forward(self, inputs):
        x = self._conv_1(inputs)
        x = F.relu(x)
        if self.groups == 2 or self.groups == 3 or self.groups == 4:
            x = self._conv_2(x)
            x = F.relu(x)
        if self.groups == 3 or self.groups == 4:
            x = self._conv_3(x)
            x = F.relu(x)
        if self.groups == 4:
            x = self._conv_4(x)
            x = F.relu(x)
        y = x
        x = self._pool(x)
        return x, y


class VGGNet(nn.Layer):
    def __init__(self):
        super(VGGNet, self).__init__()
        self.groups = [2, 2, 4, 4, 4]
        self._conv_block_1 = ConvBlock(3, 64, self.groups[0], name="conv1_")
        self._conv_block_2 = ConvBlock(64, 128, self.groups[1], name="conv2_")
        self._conv_block_3 = ConvBlock(128, 256, self.groups[2], name="conv3_")
        self._conv_block_4 = ConvBlock(256, 512, self.groups[3], name="conv4_")
        self._conv_block_5 = ConvBlock(512, 512, self.groups[4], name="conv5_")

    def forward(self, inputs):
        x, y = self._conv_block_1(inputs)
        x, y = self._conv_block_2(x)
        x, y = self._conv_block_3(x)
        x, y = self._conv_block_4(x)
        _, y = self._conv_block_5(x)
        return y
vgg19 = VGGNet()
vgg19.set_state_dict(P.load('/home/aistudio/work/vgg19_ww.pdparams'))
vgg19.eval()

辅助函数

在训练迭代中显示图像,以观察效果。

def show_image(srresnet=None, srgan=None, path=None):
    if srresnet == None:
        srresnet = G()
    srresnet.eval()
    if srgan == None:
        srgan = G() 
    srgan.eval()
    fig = plt.figure(figsize=(25, 25))
    gs = plt.GridSpec(1, 4)
    gs.update(wspace=0.1, hspace=0.1)
    if path == None:
        image = Image.open(PATH+DIRS[np.random.randint(len(DIRS))])
    else:
        image = Image.open(path)
    image = image.crop([0,0,image.size[0]//SCALE*SCALE,image.size[1]//SCALE*SCALE])
    # image = image.crop([0,0,40,40])
    LR0 = image.resize((image.size[0]//SCALE,image.size[1]//SCALE),Image.BICUBIC)
    LR = np.array(LR0).astype('float32').reshape([LR0.size[1],LR0.size[0],3,1]).transpose([3,2,0,1]) / 255 * 2 - 1
    LSR_srresnet = srresnet(P.to_tensor(LR)).numpy()
    LSR_srresnet = LSR_srresnet.reshape([3,LR0.size[1]*SCALE,LR0.size[0]*SCALE]).transpose([1,2,0])
    # LSR_srresnet = Image.fromarray(np.uint8((LSR_srresnet+1)/2*255)) ### 亮斑的罪魁祸首
    LSR_srresnet = (LSR_srresnet+1)/2
    LSR_srgan = srgan(P.to_tensor(LR)).numpy()
    print(np.max(LSR_srgan), np.min(LSR_srgan))
    LSR_srgan = LSR_srgan.reshape([3,LR0.size[1]*SCALE,LR0.size[0]*SCALE]).transpose([1,2,0])
    # LSR_srgan = Image.fromarray(np.uint8((LSR_srgan+1)/2*255)) ### 亮斑的罪魁祸首
    LSR_srgan = (LSR_srgan+1)/2
    ax = plt.subplot(gs[0])
    plt.imshow(LR0)
    plt.title('LR')
    ax = plt.subplot(gs[1])
    plt.imshow(LSR_srresnet)
    plt.title('SRResNet')
    ax = plt.subplot(gs[2])
    plt.imshow(LSR_srgan)
    plt.title('SRGAN')
    ax = plt.subplot(gs[3])
    plt.imshow(image)
    plt.title('HR')
    plt.show()


show_image()

 训练

为了与SRGAN作比较,同时训练一个SRResNet,也就是只使用了生成器,并只用L2损失来训练的网络。
SRGAN生成器的损失 = 图像L2损失 + λ1×感知损失 + λ2×对抗损失, 其中λ1=1e-2, λ2=1e-2。
SRResNet和SRGAN的生成器相同初始化。
由于Celeba比DIV2K图像数量多很多,epoch可以相对少一些。

def srresnet_trainer(lr, hr, srresnet, optimizer_srresnet):
    sr = srresnet(lr)
    loss = P.mean((sr-hr)**2)
    srresnet.clear_gradients()
    loss.backward()
    optimizer_srresnet.minimize(loss)


def srgan_trainer(lr, hr, srgan_g, srgan_d, vgg, optimizer_srgan_g, optimizer_srgan_d, λ1=1e-2, λ2=1e-2):
    sr = srgan_g(lr)
    f = vgg(P.concat([sr,hr],axis=0))
    loss_content = P.mean((sr-hr)**2) + λ1*P.mean((f[:f.shape[0]//2,:,:,:]-f[f.shape[0]//2:,:,:,:])**2)
    d = srgan_d(P.concat([sr,hr],axis=0))
    loss_adversarial_g = P.mean(-P.log(d[:d.shape[0]//2,:]+1e-8))
    loss_adversarial_d = (P.mean(-P.log(d[d.shape[0]//2:,:]+1e-8)) + P.mean(-P.log(1-d[:d.shape[0]//2,:]+1e-8))) / 2
    loss_g = loss_content + λ2*loss_adversarial_g
    vgg.clear_gradients()
    srgan_g.clear_gradients()
    srgan_d.clear_gradients()
    loss_g.backward(retain_graph=True)
    loss_adversarial_d.backward()
    optimizer_srgan_g.minimize(loss_g)
    optimizer_srgan_d.minimize(loss_adversarial_d)


def train(epoch_num=200,  load_model=False, batchsize=1, model_path = './output/'):
    srresnet = G()
    srgan_g = G()
    srgan_g.set_state_dict(srresnet.state_dict())
    srgan_d = D()
    srgan_d.train()
    optimizer_srresnet = P.optimizer.Adam(learning_rate=1e-4, beta1=0.9, parameters=srresnet.parameters())
    optimizer_srgan_g = P.optimizer.Adam(learning_rate=1e-4, beta1=0.9, parameters=srgan_g.parameters())
    optimizer_srgan_d = P.optimizer.Adam(learning_rate=1e-4, beta1=0.9, parameters=srgan_d.parameters())
    if load_model == True:
        srresnet.set_state_dict(P.load(model_path+'srresnet.pdparams'))
        srgan_g.set_state_dict(P.load(model_path+'srgan_g.pdparams'))
        srgan_d.set_state_dict(P.load(model_path+'srgan_d.pdparams'))
        srresnet.set_state_dict(P.load(model_path+'备用srresnet.pdparams'))
        srgan_g.set_state_dict(P.load(model_path+'备用srgan_g.pdparams'))
        srgan_d.set_state_dict(P.load(model_path+'备用srgan_d.pdparams'))

    iteration_num = 0
    for epoch in range(epoch_num):
        reader = reader_patch(batchsize)
        for iteration in range(len(DIRS)):
            srresnet.train()
            srgan_g.train()
            iteration_num += 1             
            LR, HR = next(reader)
            LR, HR = data_augmentation(LR, HR)
            LR = P.to_tensor(LR)
            HR = P.to_tensor(HR)
            srresnet_trainer(LR, HR, srresnet, optimizer_srresnet)
            srgan_trainer(LR, HR, srgan_g, srgan_d, vgg19, optimizer_srgan_g, optimizer_srgan_d)

            if(iteration_num % 100 == 0):
                print('Epoch: ', epoch, ', Iteration: ', iteration_num)            
                P.save(srresnet.state_dict(), model_path+'srresnet.pdparams')
                P.save(srgan_g.state_dict(), model_path+'srgan_g.pdparams')
                P.save(srgan_d.state_dict(), model_path+'srgan_d.pdparams')
                P.save(srresnet.state_dict(), model_path+'备用srresnet.pdparams')
                P.save(srgan_g.state_dict(), model_path+'备用srgan_g.pdparams')
                P.save(srgan_d.state_dict(), model_path+'备用srgan_d.pdparams')
                show_image(srresnet, srgan_g)  


# train(epoch_num=1,  load_model=False, batchsize=16)
# train(epoch_num=998,  load_model=True, batchsize=16)

测试

可以看到图中有一些斑点,根据我的猜测,这是训练不充分导致的,总体上SRGAN的斑点更多,说明它比SRResNet需要更多训练,也就是它的上限更高。 老天爷,我之前竟然装模作样瞎分析一番,尴了个大尬。。。不删了,作为我成长的见证。。。出现斑点的原因其实是用了Image.fromarray(np.uint8())!不过说训练不充分也有道理,训练充分的话就不会超出范围,也就没这个幺蛾子啦。。
相对SRResNet来说,SRGAN不那么平滑,但是有些细节并不准确,更像是噪声,而且有时会出现奇怪的东西,例如额头上的亮光。

srresnet = G()
srgan_g = G()
model_path = './output/'
srresnet.set_state_dict(P.load(model_path+'srresnet.pdparams'))
srgan_g.set_state_dict(P.load(model_path+'srgan_g.pdparams'))
show_image(srresnet, srgan_g)   

猜你喜欢

转载自blog.csdn.net/qq_39312146/article/details/134608239