好像还挺好玩的GAN7——CycleGAN实现图像风格转换

学习前言

CycleGAN好像可以换皮肤,有必要了解一下。
在这里插入图片描述

什么是CycleGAN

CycleGAN是一种完成图像到图像的转换的一种GAN。

图像到图像的转换是一类视觉和图形问题,其目标是获得输入图像和输出图像之间的映射。

但是,对于许多任务,配对的训练数据将不可用。

CycleGAN提出了一种在没有成对例子的情况下学习将图像从源域X转换到目标域Y的方法。

在这里插入图片描述

keras_contrib库的Windows安装

由于CycleGAN要用到InstanceNormalization,这个函数在普通的keras内不存在,所以要安装一个新的库。

首先去github上下载https://github.com/keras-team/keras-contrib库,下载完后解压。
在这里插入图片描述
下载完成后利用cmd进入该文件夹,并利用python setup.py install进行安装。
在这里插入图片描述
安装完成后发现其实python库里面还是没有InstanceNormalization,于是我们需要把
在这里插入图片描述
这两个文件,移动到anaconda的环境中,具体要放到Lib下的site-packages中。
在这里插入图片描述
这样InstanceNormalization就可以正常使用了!

代码与训练数据的下载

这是我的github连接,代码可以在上面下载:
https://github.com/bubbliiiing/GAN-keras

这是斑马to黄种马的数据集下载:
https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/horse2zebra.zip
苹果to橘子数据集下载:
https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/apple2orange.zip
画作to照片数据集下载:
https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/monet2photo.zip

神经网络构建

1、Generator

生成网络的目标是:
输入一张图片,转化成自己期望的风格的那张图片。
就好像输入斑马,转化成一匹马。
在这里插入图片描述
输入一幅画作,转化成一张风景照片。
在这里插入图片描述
因此对于CycleGAN而言,它的输入是一张图像,输出也是一张图像。

这样的结构与我们所学过的语义分割的形式非常类似,因此需要先进行下采样后再进行上采样!

本文章的网络结构采用的与论文中的相同。

import keras
from keras.models import *
from keras.layers import *
from keras import layers
import keras.backend as K
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization

IMAGE_ORDERING = 'channels_last'
def one_side_pad( x ):
    x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x)
    if IMAGE_ORDERING == 'channels_first':
        x = Lambda(lambda x : x[: , : , :-1 , :-1 ] )(x)
    elif IMAGE_ORDERING == 'channels_last':
        x = Lambda(lambda x : x[: , :-1 , :-1 , :  ] )(x)
    return x

def identity_block(input_tensor, kernel_size, filter_num, block):


    conv_name_base = 'res' + block + '_branch'
    in_name_base = 'in' + block + '_branch'
    # 1x1压缩
    x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(input_tensor)
    x = Conv2D(filter_num, (3, 3) , data_format=IMAGE_ORDERING , name=conv_name_base + '2a')(x)
    x = InstanceNormalization(axis=3,name=in_name_base + '2a')(x)
    x = Activation('relu')(x)

    x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x)
    x = Conv2D(filter_num , (3, 3), data_format=IMAGE_ORDERING , name=conv_name_base + '2c')(x)
    x = InstanceNormalization(axis=3,name=in_name_base + '2c')(x)
    # 残差网络
    x = layers.add([x, input_tensor])
    x = Activation('relu')(x)
    return x


def get_resnet(input_height, input_width, channel):
    img_input = Input(shape=(input_height,input_width , 3 ))

    x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input)
    x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING)(x)
    x = InstanceNormalization(axis=3)(x)
    x = Activation('relu')(x)
	# 下采样
    x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x)
    x = Conv2D(128, (3, 3), data_format=IMAGE_ORDERING, strides=2)(x)
    x = InstanceNormalization(axis=3)(x)
    x = Activation('relu')(x)

    x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x)
    x = Conv2D(256, (3, 3), data_format=IMAGE_ORDERING, strides=2)(x)
    x = InstanceNormalization(axis=3)(x)
    x = Activation('relu')(x)
	# 原论文中这个残差部分进行了9次,但是我的1660ti+6gb因为显存跑不起来
	# 于是我改成了6次
	# Keras比较吃显存,pytorch版本的可以跑。
    for i in range(6):
        x = identity_block(x, 3, 256, block=str(i))
	# 上采样
    x = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(x)
    x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x)
    x = Conv2D(128, (3, 3), data_format=IMAGE_ORDERING)(x)
    x = InstanceNormalization(axis=3)(x)
    x = Activation('relu')(x)
    
    x = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(x)
    x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x)
    x = Conv2D(64, (3, 3), data_format=IMAGE_ORDERING)(x)
    x = InstanceNormalization(axis=3)(x)
    x = Activation('relu')(x)    

    x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(x)
    x = Conv2D(channel, (7, 7), data_format=IMAGE_ORDERING)(x)
    x = Activation('tanh')(x)  
    model = Model(img_input,x)
    return model

2、Discriminator

判别模型的目的是根据输入的图片判断出真伪
在CycleGAN中,Dicriminator的训练的loss函数使用的是LSGAN中所提到均方差,这种loss可以提高假图像的精度。
在这里最后卷积完后的shape为(16,16,1)。

普通判别模型是全连接到一个1维向量上,而这里却不是,要怎么理解呢?

可以这样去感性理解这个操作:
一共有16x16个评委对这个图片进行判断,这其中可能有些评委判断错了,有些评委判断对了,根据每个评委判断的内容可以对判别模型进行修正
在这里插入图片描述

def build_discriminator(self):

    def conv2d(layer_input, filters, f_size=4, normalization=True):
        d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
        if normalization:
            d = InstanceNormalization()(d)
        d = LeakyReLU(alpha=0.2)(d)
        return d

    img = Input(shape=self.img_shape)

    d1 = conv2d(img, 64, normalization=False)
    d2 = conv2d(d1, 128)
    d3 = conv2d(d2, 256)
    d4 = conv2d(d3, 512)
    # 对每个像素点判断是否有效
    validity = Conv2D(1, kernel_size=3, strides=1, padding='same')(d4)

    return Model(img, validity)

训练思路

CycleGAN的训练思路分为如下几个步骤:
1、创建两个生成模型,一个用于从图片风格A转换成图片风格B,一个用于从图片风格B转换成图片风格A。
2、创建两个判别模型,分别用于风格A图片的真伪判断和风格B图片的真伪判断。
3、判别模型的训练所用的损失函数与LSGAN相同,通过判断是否正确进行训练。
4、生成模型的训练需要满足下面六个准则:

  • a、从图片风格A转换成图片风格B的假图像需要成功欺骗判断模型B;
  • b、从图片风格B转换成图片风格A的假图像需要成功欺骗判断模型A;
  • c、从图片风格A转换成图片风格B的假图像可以通过生成模型BA成功转换成图片A;
  • d、从图片风格B转换成图片风格A的假图像可以通过生成模型AB成功转换成图片B;
  • e、真实图片A通过生成模型BA,不会发生变化。
  • f、真实图片B通过生成模型AB,不会发生变化。

其中c、d准则是为了让生成器找到最需要修改的地方,比如 斑马转黄马就只要改变马的颜色就可以欺骗判断模型,风格A的图片经过生成模型AB只需要转化 斑马 即可。
其中e、f准则是为了让 两种生成模型可以区分两种图片风格,生成模型AB只对风格A的图片进行处理,生成模型BA只对风格B的图片进行处理。

实现全部代码

1、网络部分全部代码

用于存放网络:

import keras
from keras.models import *
from keras.layers import *
from keras import layers
import keras.backend as K
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization

IMAGE_ORDERING = 'channels_last'
def one_side_pad( x ):
    x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x)
    if IMAGE_ORDERING == 'channels_first':
        x = Lambda(lambda x : x[: , : , :-1 , :-1 ] )(x)
    elif IMAGE_ORDERING == 'channels_last':
        x = Lambda(lambda x : x[: , :-1 , :-1 , :  ] )(x)
    return x

def identity_block(input_tensor, kernel_size, filter_num, block):


    conv_name_base = 'res' + block + '_branch'
    in_name_base = 'in' + block + '_branch'
    # 1x1压缩
    x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(input_tensor)
    x = Conv2D(filter_num, (3, 3) , data_format=IMAGE_ORDERING , name=conv_name_base + '2a')(x)
    x = InstanceNormalization(axis=3,name=in_name_base + '2a')(x)
    x = Activation('relu')(x)

    x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x)
    x = Conv2D(filter_num , (3, 3), data_format=IMAGE_ORDERING , name=conv_name_base + '2c')(x)
    x = InstanceNormalization(axis=3,name=in_name_base + '2c')(x)
    # 残差网络
    x = layers.add([x, input_tensor])
    x = Activation('relu')(x)
    return x


def get_resnet(input_height, input_width, channel):
    img_input = Input(shape=(input_height,input_width , 3 ))

    x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input)
    x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING)(x)
    x = InstanceNormalization(axis=3)(x)
    x = Activation('relu')(x)

    x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x)
    x = Conv2D(128, (3, 3), data_format=IMAGE_ORDERING, strides=2)(x)
    x = InstanceNormalization(axis=3)(x)
    x = Activation('relu')(x)

    x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x)
    x = Conv2D(256, (3, 3), data_format=IMAGE_ORDERING, strides=2)(x)
    x = InstanceNormalization(axis=3)(x)
    x = Activation('relu')(x)

    for i in range(6):
        x = identity_block(x, 3, 256, block=str(i))

    x = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(x)
    x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x)
    x = Conv2D(128, (3, 3), data_format=IMAGE_ORDERING)(x)
    x = InstanceNormalization(axis=3)(x)
    x = Activation('relu')(x)
    
    x = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(x)
    x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x)
    x = Conv2D(64, (3, 3), data_format=IMAGE_ORDERING)(x)
    x = InstanceNormalization(axis=3)(x)
    x = Activation('relu')(x)    

    x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(x)
    x = Conv2D(channel, (7, 7), data_format=IMAGE_ORDERING)(x)
    x = Activation('tanh')(x)  
    model = Model(img_input,x)
    return model

2、data_loader全部代码

该部分用于对数据进行加载:

import cv2
from glob import glob
import numpy as np

class DataLoader():
    def __init__(self, dataset_name, img_res=(128, 128)):
        self.dataset_name = dataset_name
        self.img_res = img_res

    def load_data(self, domain, batch_size=1, is_testing=False):
        data_type = "train%s" % domain if not is_testing else "test%s" % domain
        path = glob('./datasets/%s/%s/*' % (self.dataset_name, data_type))

        batch_images = np.random.choice(path, size=batch_size)

        imgs = []
        for img_path in batch_images:
            img = self.imread(img_path)
            if not is_testing:
                img = cv2.resize(img, self.img_res)

                if np.random.random() > 0.5:
                    img = np.fliplr(img)
            else:
                img = cv2.resize(img, self.img_res)
            imgs.append(img)

        imgs = np.array(imgs)/127.5 - 1.

        return imgs

    def load_batch(self, batch_size=1, is_testing=False):
        data_type = "train" if not is_testing else "val"
        path_A = glob('./datasets/%s/%sA/*' % (self.dataset_name, data_type))
        path_B = glob('./datasets/%s/%sB/*' % (self.dataset_name, data_type))

        self.n_batches = int(min(len(path_A), len(path_B)) / batch_size)
        total_samples = self.n_batches * batch_size

        path_A = np.random.choice(path_A, total_samples, replace=False)
        path_B = np.random.choice(path_B, total_samples, replace=False)

        for i in range(self.n_batches-1):
            batch_A = path_A[i*batch_size:(i+1)*batch_size]
            batch_B = path_B[i*batch_size:(i+1)*batch_size]
            imgs_A, imgs_B = [], []
            for img_A, img_B in zip(batch_A, batch_B):
                img_A = self.imread(img_A)
                img_B = self.imread(img_B)

                img_A = cv2.resize(img_A, self.img_res)
                img_B = cv2.resize(img_B, self.img_res)

                if not is_testing and np.random.random() > 0.5:
                        img_A = np.fliplr(img_A)
                        img_B = np.fliplr(img_B)

                imgs_A.append(img_A)
                imgs_B.append(img_B)

            imgs_A = np.array(imgs_A,dtype=np.float32)/127.5 - 1.
            imgs_B = np.array(imgs_B,dtype=np.float32)/127.5 - 1.

            yield imgs_A, imgs_B

    def load_img(self, path):
        img = self.imread(path)
        img = img = cv2.resize(img, self.img_res)
        img = img/127.5 - 1.
        return img[np.newaxis, :, :, :]

    def imread(self, path):
        img = cv2.imread(path)
        img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
        return img

3、主函数全部代码

训练代码

from __future__ import print_function, division
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
from keras.layers.advanced_activations import LeakyReLU
from keras.backend.tensorflow_backend import set_session
from keras.optimizers import Adam
from keras import backend as K
from keras.layers import *
from keras.models import *
import keras

from nets.resnet import get_resnet
from data_loader import DataLoader
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import datetime
import sys
import os

class CycleGAN():
    def __init__(self):
        config = tf.ConfigProto()
        config.gpu_options.allocator_type = 'BFC' #A "Best-fit with coalescing" algorithm, simplified from a version of dlmalloc.
        config.gpu_options.per_process_gpu_memory_fraction = 0.7
        config.gpu_options.allow_growth = True
        set_session(tf.Session(config=config)) 
        # 输入图片的大小为256x256x3
        self.img_rows = 256
        self.img_cols = 256
        self.channels = 3
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        
        # 载入数据
        self.dataset_name = 'horse2zebra'
        self.data_loader = DataLoader(dataset_name=self.dataset_name,
                                      img_res=(self.img_rows, self.img_cols))

        # patch
        patch = int(self.img_rows / 2**4)
        self.disc_patch = (patch, patch, 1)

        # Loss weights
        self.lambda_cycle = 5
        
        self.lambda_id = 2.5

        optimizer = Adam(0.0002, 0.5)

        #-------------------------#
        #   建立判别网络
        #-------------------------#
        self.d_A = self.build_discriminator()
        self.d_B = self.build_discriminator()

        self.d_A.compile(loss='mse',
            optimizer=optimizer,
            metrics=['accuracy'])
        self.d_B.compile(loss='mse',
            optimizer=optimizer,
            metrics=['accuracy'])

        #-------------------------#
        #   创建生成模型
        #-------------------------#

        # 创建生成模型
        self.g_AB = self.build_generator()
        self.g_BA = self.build_generator()

        img_A = Input(shape=self.img_shape)
        img_B = Input(shape=self.img_shape)

        # 生成假的B图片
        fake_B = self.g_AB(img_A)
        # 生成假的A图片
        fake_A = self.g_BA(img_B)

        # 从B再生成A
        reconstr_A = self.g_BA(fake_B)
        # 从B再生成A
        reconstr_B = self.g_AB(fake_A)
        self.g_AB.summary()
        # 通过g_BA传入img_A
        img_A_id = self.g_BA(img_A)
        # 通过g_AB传入img_B
        img_B_id = self.g_AB(img_B)

        # 在这一部分,评价模型不训练。
        self.d_A.trainable = False
        self.d_B.trainable = False

        # 评价是否为真
        valid_A = self.d_A(fake_A)
        valid_B = self.d_B(fake_B)

        # 训练
        self.combined = Model(inputs=[img_A, img_B],
                              outputs=[ valid_A, valid_B,
                                        reconstr_A, reconstr_B,
                                        img_A_id, img_B_id ])
        self.combined.compile(loss=['mse', 'mse',
                                    'mae', 'mae',
                                    'mae', 'mae'],
                            loss_weights=[0.5, 0.5,
                                        self.lambda_cycle, self.lambda_cycle,
                                        self.lambda_id, self.lambda_id ],
                            optimizer=optimizer)

    def build_generator(self):
        model = get_resnet(self.img_rows,self.img_cols,self.channels)
        return model

    def build_discriminator(self):

        def conv2d(layer_input, filters, f_size=4, normalization=True):
            d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            if normalization:
                d = InstanceNormalization()(d)
            d = LeakyReLU(alpha=0.2)(d)
            return d

        img = Input(shape=self.img_shape)
        # 128
        d1 = conv2d(img, 64, normalization=False)
        # 64
        d2 = conv2d(d1, 128)
        # 32
        d3 = conv2d(d2, 256)
        # 16
        d4 = conv2d(d3, 512)
        # 对每个像素点判断是否有效
        validity = Conv2D(1, kernel_size=3, strides=1, padding='same')(d4)

        return Model(img, validity)
    
    def scheduler(self,models,epoch):
        # 每隔100个epoch,学习率减小为原来的1/2
        if epoch % 20 == 0 and epoch != 0:
            for model in models:
                lr = K.get_value(model.optimizer.lr)
                K.set_value(model.optimizer.lr, lr * 0.5)
            print("lr changed to {}".format(lr * 0.5))
        
    def train(self, init_epoch, epochs, batch_size=1, sample_interval=50):

        start_time = datetime.datetime.now()

        valid = np.ones((batch_size,) + self.disc_patch)
        fake = np.zeros((batch_size,) + self.disc_patch)
        if init_epoch!= 0:
            self.d_A.load_weights("weights/%s/d_A_epoch%d.h5" % (self.dataset_name, init_epoch),skip_mismatch=True)
            self.d_B.load_weights("weights/%s/d_B_epoch%d.h5" % (self.dataset_name, init_epoch),skip_mismatch=True)
            self.g_AB.load_weights("weights/%s/g_AB_epoch%d.h5" % (self.dataset_name, init_epoch),skip_mismatch=True)
            self.g_BA.load_weights("weights/%s/g_BA_epoch%d.h5" % (self.dataset_name, init_epoch),skip_mismatch=True)

        for epoch in range(init_epoch,epochs):
            self.scheduler([self.combined,self.d_A,self.d_B],epoch)
            for batch_i, (imgs_A, imgs_B) in enumerate(self.data_loader.load_batch(batch_size)):
                # ------------------ #
                #  训练生成模型
                # ------------------ #
                g_loss = self.combined.train_on_batch([imgs_A, imgs_B],
                                                        [valid, valid,
                                                        imgs_A, imgs_B,
                                                        imgs_A, imgs_B])
                # ---------------------- #
                #  训练评价者
                # ---------------------- #
                # A到B的假图片,此时生成的是假橘子
                fake_B = self.g_AB.predict(imgs_A)
                # B到A的假图片,此时生成的是假苹果
                fake_A = self.g_BA.predict(imgs_B)
                # 判断真假图片,并以此进行训练
                dA_loss_real = self.d_A.train_on_batch(imgs_A, valid)
                dA_loss_fake = self.d_A.train_on_batch(fake_A, fake)
                dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake)
                # 判断真假图片,并以此进行训练
                dB_loss_real = self.d_B.train_on_batch(imgs_B, valid)
                dB_loss_fake = self.d_B.train_on_batch(fake_B, fake)
                dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)

                d_loss = 0.5 * np.add(dA_loss, dB_loss)
                

                elapsed_time = datetime.datetime.now() - start_time

                print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %05f, adv: %05f, recon: %05f, id: %05f] time: %s " \
                                                                        % ( epoch, epochs,
                                                                            batch_i, self.data_loader.n_batches,
                                                                            d_loss[0], 100*d_loss[1],
                                                                            g_loss[0],
                                                                            np.mean(g_loss[1:3]),
                                                                            np.mean(g_loss[3:5]),
                                                                            np.mean(g_loss[5:6]),
                                                                            elapsed_time))

                if batch_i % sample_interval == 0:
                    self.sample_images(epoch, batch_i)
                    if epoch % 5 == 0 and epoch != init_epoch:
                        os.makedirs('weights/%s' % self.dataset_name, exist_ok=True)
                        self.d_A.save_weights("weights/%s/d_A_epoch%d.h5" % (self.dataset_name, epoch))
                        self.d_B.save_weights("weights/%s/d_B_epoch%d.h5" % (self.dataset_name, epoch))
                        self.g_AB.save_weights("weights/%s/g_AB_epoch%d.h5" % (self.dataset_name, epoch))
                        self.g_BA.save_weights("weights/%s/g_BA_epoch%d.h5" % (self.dataset_name, epoch))

    def sample_images(self, epoch, batch_i):
        os.makedirs('images/%s' % self.dataset_name, exist_ok=True)
        r, c = 2, 2

        imgs_A = self.data_loader.load_data(domain="A", batch_size=1, is_testing=True)
        imgs_B = self.data_loader.load_data(domain="B", batch_size=1, is_testing=True)

        fake_B = self.g_AB.predict(imgs_A)
        fake_A = self.g_BA.predict(imgs_B)

        gen_imgs = np.concatenate([imgs_A, fake_B, imgs_B, fake_A])

        gen_imgs = 0.5 * gen_imgs + 0.5

        titles = ['Original', 'Translated', 'Reconstructed']
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt])
                axs[i, j].set_title(titles[j])
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("images/%s/%d_%d.png" % (self.dataset_name, epoch, batch_i))
        plt.close()

if __name__ == '__main__':
    gan = CycleGAN()
    gan.train(init_epoch=15, epochs=200, batch_size=1, sample_interval=200)

实现效果:

下图为训练了50代的效果(因为减少了残差模块的次数,而且训练的不够多,导致效果没有那么好):
在这里插入图片描述

发布了167 篇原创文章 · 获赞 112 · 访问量 24万+

猜你喜欢

转载自blog.csdn.net/weixin_44791964/article/details/103780922