Common Data Augmentation Series: The Magical GAN, CycleGAN

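Preface

I spent the whole day writing code and I am both tired and lazy. I had planned to write a walkthrough of YOLOv4 today, but that would take too much typing, so it will have to wait until I have time. Yesterday I had an idea: talk about using a GAN for data augmentation, which papers often propose as an improvement. The GAN in question is CycleGAN, and I quite like the idea behind it.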

GitHub

https://github.com/yanjingke/cyclegan

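So what is CycleGAN?

Many forgers of famous paintings spend a lifetime trying to imitate the style of master artists. CycleGAN can achieve a first approximation of this ability, which is known as style transfer.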

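Installing the CycleGAN dependencies

CycleGAN uses InstanceNormalization, a layer that does not exist in plain Keras, so an extra library has to be installed.

First download the keras-contrib library from https://github.com/keras-team/keras-contrib and unzip it. Then, in cmd, change into that directory and run python setup.py install. If downloading from GitHub is slow, you can clone the project to Gitee and download it from there. Search online for the details.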
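After installing, one quick way to confirm that Python can see the package is to look it up on the import path without importing it. This is a minimal sketch using only the standard library. `is_installed` is a hypothetical helper name, and `keras_contrib` is the top-level module name used by the import in the generator code.

```python
import importlib.util


def is_installed(module_name: str) -> bool:
    """Return True if a top-level module can be located on the import path."""
    return importlib.util.find_spec(module_name) is not None


if __name__ == "__main__":
    # Report whether the two packages this post relies on are visible.
    for name in ("keras", "keras_contrib"):
        status = "found" if is_installed(name) else "missing"
        print(f"{name}: {status}")
```

If `keras_contrib` is reported as missing, the install step most likely ran under a different Python environment than the one you are using now.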

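Building the generator network (Generator)

The goal of the generator network is to produce the target image you want to translate to.
In the generator we use ResNet, proposed by Kaiming He. Image restoration and many other image tasks work the same way: downsample first, then upsample, to predict and repair the image. I will not describe ResNet in detail here. Its key ingredient is the residual block, which lets the network go deeper and mitigates vanishing gradients.

The code is as follows: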

import keras
from keras.models import *
from keras.layers import *
from keras import layers
import keras.backend as K
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization

IMAGE_ORDERING = 'channels_last'
def one_side_pad( x ):
    x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x)
    if IMAGE_ORDERING == 'channels_first':
        x = Lambda(lambda x : x[: , : , :-1 , :-1 ] )(x)
    elif IMAGE_ORDERING == 'channels_last':
        x = Lambda(lambda x : x[: , :-1 , :-1 , :  ] )(x)
    return x
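
def identity_block(input_tensor, kernel_size, filter_num, block):
    # Residual block: two 3x3 convolutions with InstanceNormalization and a skip connection.
    # kernel_size is accepted but currently unused; both convolutions are fixed at 3x3.
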
    conv_name_base = 'res' + block + '_branch'
    in_name_base = 'in' + block + '_branch'

    x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(input_tensor)
    x = Conv2D(filter_num, (3, 3) , data_format=IMAGE_ORDERING , name=conv_name_base + '2a')(x)
    x = InstanceNormalization(axis=3,name=in_name_base + '2a')(x)
    x = Activation('relu')(x)

    x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x)
    x = Conv2D(filter_num , (3, 3), data_format=IMAGE_ORDERING , name=conv_name_base + '2c')(x)
    x = InstanceNormalization(axis=3,name=in_name_base + '2c')(x)
    # Residual connection: add the block input back onto the output
    x = layers.add([x, input_tensor])
    x = Activation('relu')(x)
    return x


def get_resnet(input_height, input_width, channel):
    img_input = Input(shape=(input_height,input_width , 3 ))
    # 128,128,3 -> 128,128,64
    x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input)
    x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING)(x)
    x = InstanceNormalization(axis=3)(x)
    x = Activation('relu')(x)

    # 128,128,64 -> 64,64,128
    x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x)
    x = Conv2D(128, (3, 3), data_format=IMAGE_ORDERING, strides=2)(x)
    x = InstanceNormalization(axis=3)(x)
    x = Activation('relu')(x)

    # 64,64,128 -> 32,32,256
    x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x)
    x = Conv2D(256, (3, 3), data_format=IMAGE_ORDERING, strides=2)(x)
    x = InstanceNormalization(axis=3)(x)
    x = Activation('relu')(x)

    for i in range(9):
        x = identity_block(x, 3, 256, block=str(i))
    
    # 32,32,256 -> 64,64,128
    x = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(x)
    x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x)
    x = Conv2D(128, (3, 3), data_format=IMAGE_ORDERING)(x)
    x = InstanceNormalization(axis=3)(x)
    x = Activation('relu')(x)
    
    # 64,64,128 -> 128,128,64
    x = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(x)
    x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x)
    x = Conv2D(64, (3, 3), data_format=IMAGE_ORDERING)(x)
    x = InstanceNormalization(axis=3)(x)
    x = Activation('relu')(x)    

    # 128,128,64 -> 128,128,3
    x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(x)
    x = Conv2D(channel, (7, 7), data_format=IMAGE_ORDERING)(x)
    x = Activation('tanh')(x)  
    model = Model(img_input,x)
    return model
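The shape comments in the generator can be checked by hand. Conv2D uses 'valid' padding by default, so after a symmetric ZeroPadding2D of p pixels, a kernel of size k with stride s maps a side length n to floor((n + 2p - k) / s) + 1. The sketch below traces the side length through the layers above; `conv_out` and `trace_generator` are hypothetical helper names, and the 128-pixel input is taken from the code comments.

```python
def conv_out(size, kernel, stride=1, pad=0):
    """Side length after symmetric zero padding followed by a 'valid' convolution."""
    return (size + 2 * pad - kernel) // stride + 1


def trace_generator(size=128, n_res_blocks=9):
    """Return (stage, side length) pairs for the generator defined above."""
    trace = [("input", size)]
    size = conv_out(size, kernel=7, pad=3)
    trace.append(("conv 7x7", size))
    for _ in range(2):
        size = conv_out(size, kernel=3, stride=2, pad=1)
        trace.append(("conv 3x3, stride 2", size))
    for _ in range(n_res_blocks):
        # Each residual block holds two padded 3x3 convolutions with stride 1.
        size = conv_out(conv_out(size, kernel=3, pad=1), kernel=3, pad=1)
    trace.append(("residual blocks", size))
    for _ in range(2):
        # UpSampling2D doubles the side length before the padded 3x3 convolution.
        size = conv_out(size * 2, kernel=3, pad=1)
        trace.append(("upsample x2 + conv 3x3", size))
    size = conv_out(size, kernel=7, pad=3)
    trace.append(("conv 7x7 + tanh", size))
    return trace


if __name__ == "__main__":
    for stage, side in trace_generator(128):
        print(f"{stage:>24}: {side} x {side}")
```

The residual blocks leave the spatial size unchanged, which is what allows the block input to be added to its output.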

Discriminator network (Discriminator)

The discriminator network judges whether an image is real or generated.
The code is as follows:

    def build_discriminator(self):

        def conv2d(layer_input, filters, f_size=4, normalization=True):
            d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            if normalization:
                d = InstanceNormalization()(d)
            d = LeakyReLU(alpha=0.2)(d)
            return d

        img = Input(shape=self.img_shape)
        # 64,64,64
        d1 = conv2d(img, 64, normalization=False)
        # 32,32,128
        d2 = conv2d(d1, 128)
        # 16,16,256
        d3 = conv2d(d2, 256)
        # 8,8,512
        d4 = conv2d(d3, 512)
        # Score the validity of every location in the output map
        # 64 scores in total
        # 8,8,1
        validity = Conv2D(1, kernel_size=3, strides=1, padding='same')(d4)

        return Model(img, validity)

The final output of the discriminator has shape 8×8×1. You can think of it as 64 judges, each scoring whether the image is real or fake.
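The 8×8×1 shape follows from the four stride-2 convolutions with 'same' padding, each of which maps a side length n to ceil(n / 2). Here is a minimal sketch of that arithmetic, assuming a 128×128 input as in the generator comments. `same_conv_out` and `patch_shape` are hypothetical names. The result is presumably the same shape as `self.disc_patch` in the training code, whose definition is not shown in this post.

```python
import math


def same_conv_out(size, stride):
    """Side length after a convolution with padding='same'."""
    return math.ceil(size / stride)


def patch_shape(img_size=128, n_stride2_layers=4):
    """Shape of the discriminator's output map for a square input image."""
    size = img_size
    for _ in range(n_stride2_layers):
        size = same_conv_out(size, stride=2)
    # The final 3x3 convolution has stride 1, so it keeps the size.
    size = same_conv_out(size, stride=1)
    return (size, size, 1)


if __name__ == "__main__":
    shape = patch_shape(128)
    print(shape, "->", shape[0] * shape[1], "scores per image")
```

A larger input gives more judges: under the same assumptions a 256×256 image would produce a 16×16×1 map.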


Loss computation

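The discriminator loss is mean squared error, which is said to improve prediction accuracy. The training images are not paired; instead, the training targets are produced with the guidance of the networks themselves.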
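To make the loss terms concrete, here is a dependency-free sketch of the quantities involved. Scores on real images are pushed toward 1 and scores on generated images toward 0 using MSE, and the two halves are averaged as in the training loop. The cycle and identity terms use MAE, matching the 'mae' entries in the compile call. The function names are hypothetical, and flat Python lists stand in for image tensors and score maps.

```python
def mse(pred, target):
    """Mean squared error between two equal-length sequences."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def mae(pred, target):
    """Mean absolute error between two equal-length sequences."""
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


def discriminator_loss(scores_on_real, scores_on_fake):
    """Average of the MSE against all-ones and all-zeros targets."""
    real_loss = mse(scores_on_real, [1.0] * len(scores_on_real))
    fake_loss = mse(scores_on_fake, [0.0] * len(scores_on_fake))
    return 0.5 * (real_loss + fake_loss)


def cycle_loss(original, reconstructed):
    """L1 distance between an image and its A -> B -> A reconstruction."""
    return mae(reconstructed, original)


if __name__ == "__main__":
    # A discriminator that cannot tell real from fake outputs 0.5 everywhere.
    print(discriminator_loss([0.5, 0.5], [0.5, 0.5]))
    # A reconstruction that is off by 2 in one of two pixels.
    print(cycle_loss([1.0, 4.0], [1.0, 2.0]))
```

Because the cycle loss compares an image only with its own reconstruction, no paired A/B samples are needed.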

Training

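        # Build the two generators (A -> B and B -> A)
        self.g_AB = self.build_generator()
        self.g_BA = self.build_generator()

        img_A = Input(shape=self.img_shape)
        img_B = Input(shape=self.img_shape)

        # Translate A into a fake B
        fake_B = self.g_AB(img_A)
        # Translate B into a fake A
        fake_A = self.g_BA(img_B)

        # Translate the fake B back to reconstruct A
        reconstr_A = self.g_BA(fake_B)
        # Translate the fake A back to reconstruct B
        reconstr_B = self.g_AB(fake_A)
        self.g_AB.summary()
        # Identity mapping: feed img_A through g_BA
        img_A_id = self.g_BA(img_A)
        # Identity mapping: feed img_B through g_AB
        img_B_id = self.g_AB(img_B)

        # The discriminators are not trained inside the combined model
        self.d_A.trainable = False
        self.d_B.trainable = False

        # The discriminators score the translated images
        valid_A = self.d_A(fake_A)
        valid_B = self.d_B(fake_B)

        # Combined model used to train the generators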
        self.combined = Model(inputs=[img_A, img_B],
                              outputs=[ valid_A, valid_B,
                                        reconstr_A, reconstr_B,
                                        img_A_id, img_B_id ])
        self.combined.compile(loss=['mse', 'mse',
                                    'mae', 'mae',
                                    'mae', 'mae'],
                            loss_weights=[0.5, 0.5,
                                        self.lambda_cycle, self.lambda_cycle,
                                        self.lambda_id, self.lambda_id ],
                            optimizer=optimizer)

    def train(self, init_epoch, epochs, batch_size=1, sample_interval=50):

        start_time = datetime.datetime.now()

        valid = np.ones((batch_size,) + self.disc_patch)
        fake = np.zeros((batch_size,) + self.disc_patch)
        if init_epoch!= 0:
            self.d_A.load_weights("weights/%s/d_A_epoch%d.h5" % (self.dataset_name, init_epoch),skip_mismatch=True)
            self.d_B.load_weights("weights/%s/d_B_epoch%d.h5" % (self.dataset_name, init_epoch),skip_mismatch=True)
            self.g_AB.load_weights("weights/%s/g_AB_epoch%d.h5" % (self.dataset_name, init_epoch),skip_mismatch=True)
            self.g_BA.load_weights("weights/%s/g_BA_epoch%d.h5" % (self.dataset_name, init_epoch),skip_mismatch=True)

        for epoch in range(init_epoch,epochs):
            self.scheduler([self.combined,self.d_A,self.d_B],epoch)
            for batch_i, (imgs_A, imgs_B) in enumerate(self.data_loader.load_batch(batch_size)):
                # ------------------ #
                #  Train the generators
                # ------------------ #
                g_loss = self.combined.train_on_batch([imgs_A, imgs_B],
                                                        [valid, valid,
                                                        imgs_A, imgs_B,
                                                        imgs_A, imgs_B])
                # ---------------------- #
                #  Train the discriminators
                # ---------------------- #
                # Fake image from A to B; here this produces a fake orange
                fake_B = self.g_AB.predict(imgs_A)
                # Fake image from B to A; here this produces a fake apple
                fake_A = self.g_BA.predict(imgs_B)
                # Train d_A on real and fake A images
                dA_loss_real = self.d_A.train_on_batch(imgs_A, valid)
                dA_loss_fake = self.d_A.train_on_batch(fake_A, fake)
                dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake)
                # Train d_B on real and fake B images
                dB_loss_real = self.d_B.train_on_batch(imgs_B, valid)
                dB_loss_fake = self.d_B.train_on_batch(fake_B, fake)
                dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)

                d_loss = 0.5 * np.add(dA_loss, dB_loss)
                

                elapsed_time = datetime.datetime.now() - start_time

                print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %05f, adv: %05f, recon: %05f, id: %05f] time: %s " \
                                                                        % ( epoch, epochs,
                                                                            batch_i, self.data_loader.n_batches,
                                                                            d_loss[0], 100*d_loss[1],
                                                                            g_loss[0],
                                                                            np.mean(g_loss[1:3]),
                                                                            np.mean(g_loss[3:5]),
                                                                            np.mean(g_loss[5:7]),
                                                                            elapsed_time))

                if batch_i % sample_interval == 0:
                    self.sample_images(epoch, batch_i)
                    if epoch % 5 == 0 and epoch != init_epoch:
                        os.makedirs('weights/%s' % self.dataset_name, exist_ok=True)
                        self.d_A.save_weights("weights/%s/d_A_epoch%d.h5" % (self.dataset_name, epoch))
                        self.d_B.save_weights("weights/%s/d_B_epoch%d.h5" % (self.dataset_name, epoch))
                        self.g_AB.save_weights("weights/%s/g_AB_epoch%d.h5" % (self.dataset_name, epoch))
                        self.g_BA.save_weights("weights/%s/g_BA_epoch%d.h5" % (self.dataset_name, epoch))
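For a model with several outputs, `train_on_batch` returns the weighted total loss followed by one loss per output, in the order of the `outputs` list. For the combined model above that gives seven numbers: total, two adversarial terms, two reconstruction terms, and two identity terms at indices 5 and 6. The sketch below reproduces that layout in plain Python. The values of `lambda_cycle` and `lambda_id` are assumptions chosen for illustration because the post does not show them, and the function names are hypothetical.

```python
def combined_generator_loss(per_output_losses, lambda_cycle=10.0, lambda_id=1.0):
    """Return [total, adv_A, adv_B, rec_A, rec_B, id_A, id_B].

    The weights follow the loss_weights list passed to compile() above;
    lambda_cycle and lambda_id are assumed values, not taken from the post.
    """
    weights = [0.5, 0.5, lambda_cycle, lambda_cycle, lambda_id, lambda_id]
    total = sum(w * l for w, l in zip(weights, per_output_losses))
    return [total] + list(per_output_losses)


def mean(values):
    """Arithmetic mean of a non-empty sequence."""
    return sum(values) / len(values)


def summarize(g_loss):
    """Group the seven entries the way the training log reports them."""
    return {
        "total": g_loss[0],
        "adv": mean(g_loss[1:3]),
        "recon": mean(g_loss[3:5]),
        "id": mean(g_loss[5:7]),
    }


if __name__ == "__main__":
    g_loss = combined_generator_loss([0.2, 0.4, 0.1, 0.3, 0.05, 0.15])
    print(summarize(g_loss))
```

With a large `lambda_cycle`, the reconstruction terms dominate the total, which pushes the generators to preserve image content while changing style.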

Reposted from blog.csdn.net/qq_35914625/article/details/107969196