常用数据增强系列:神奇的GAN——CycleGAN
学习目录
学习前言
写了一天代码,又累又懒。本来今天想解读一下yolov4的,感觉要打的字太多,等有空再写吧。昨天突然来了灵感,想讲讲一些论文里常提到的用GAN来做数据增强的方法(一种改进思路)——CycleGAN,它的思想我也挺喜欢。
github
https://github.com/yanjingke/cyclegan
那么什么是CycleGan?
许多名画造假者费尽毕生的心血,试图模仿出艺术名家的风格。CycleGAN就可以初步实现这个神奇的功能。这个功能就是风格迁移。
CycleGAN的安装
由于CycleGAN要用到InstanceNormalization,而这个层在普通的keras内不存在,所以要安装一个新的库。
首先去github上下载 https://github.com/keras-team/keras-contrib 库,下载完后解压;然后在cmd里进入解压后的目录,执行 python setup.py install。github下载慢的话,可以先把项目clone到码云里面,再从码云下载,具体做法可以自行百度。
生成网络的构建Generator
生成网络的目标主要是生成你想转换成的那张目标图片。
在生成网络中我们使用何恺明大神提出的resnet。其实包括图像复原在内的很多图像相关任务,都是先下采样、再上采样,对图片进行预测和修复。resnet我就不具体介绍了:它主要利用了残差块,加深了网络结构,缓解了梯度消失的问题。
具体代码如下:
import keras
from keras.models import *
from keras.layers import *
from keras import layers
import keras.backend as K
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
IMAGE_ORDERING = 'channels_last'
def one_side_pad(x):
    """Zero-pad both spatial axes by one, then drop the last row and column.

    A symmetric one-pixel ``ZeroPadding2D`` is applied first.  A ``Lambda``
    slice then removes the trailing index of each spatial axis, so the net
    effect is a single pixel of padding on the leading side only.  The axes
    that get sliced are chosen from the module-level ``IMAGE_ORDERING``; for
    any other value of that constant the symmetric padding is returned as is.

    Args:
        x: Keras tensor to pad.

    Returns:
        The padded (and, for a known ordering, trimmed) Keras tensor.
    """
    padded = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x)

    if IMAGE_ORDERING == 'channels_first':
        # Spatial axes are 2 and 3.
        return Lambda(lambda t: t[:, :, :-1, :-1])(padded)

    if IMAGE_ORDERING == 'channels_last':
        # Spatial axes are 1 and 2.
        return Lambda(lambda t: t[:, :-1, :-1, :])(padded)

    return padded
def identity_block(input_tensor, kernel_size, filter_num, block):
    """Residual block: two padded convolutions plus a skip connection.

    Each convolution is preceded by explicit zero padding sized so that the
    spatial dimensions are preserved, and followed by instance
    normalization.  The block input is added to the result of the second
    normalization and the sum is passed through a ReLU.

    Args:
        input_tensor: Keras tensor entering the block.
        kernel_size: Convolution kernel size, either a single int or a
            ``(rows, cols)`` pair.
        filter_num: Number of filters of both convolutions.  The residual
            add needs matching shapes, so this presumably has to equal the
            channel count of ``input_tensor``.
        block: String used to build unique layer names
            (``res<block>_branch2a`` and so on).

    Returns:
        The output Keras tensor of the block.
    """
    conv_name_base = 'res' + block + '_branch'
    in_name_base = 'in' + block + '_branch'

    # Accept either an int or a (rows, cols) pair.
    try:
        k_rows, k_cols = kernel_size
    except TypeError:
        k_rows = k_cols = kernel_size

    # Total padding of (k - 1) per axis keeps the spatial size unchanged for
    # a stride-1 'valid' convolution; for a 3x3 kernel this is ((1, 1), (1, 1)).
    padding = (
        ((k_rows - 1) // 2, k_rows // 2),
        ((k_cols - 1) // 2, k_cols // 2),
    )

    x = ZeroPadding2D(padding, data_format=IMAGE_ORDERING)(input_tensor)
    x = Conv2D(filter_num, (k_rows, k_cols), data_format=IMAGE_ORDERING,
               name=conv_name_base + '2a')(x)
    x = InstanceNormalization(axis=3, name=in_name_base + '2a')(x)
    x = Activation('relu')(x)

    x = ZeroPadding2D(padding, data_format=IMAGE_ORDERING)(x)
    x = Conv2D(filter_num, (k_rows, k_cols), data_format=IMAGE_ORDERING,
               name=conv_name_base + '2c')(x)
    x = InstanceNormalization(axis=3, name=in_name_base + '2c')(x)

    # Residual (skip) connection.
    x = layers.add([x, input_tensor])
    x = Activation('relu')(x)
    return x
def get_resnet(input_height, input_width, channel):
    """Build the ResNet-style generator network.

    Layout: a 7x7 convolution stem, two stride-2 down-sampling convolutions,
    nine ``identity_block`` residual blocks at 256 filters, two
    up-sample + convolution stages, and a final 7x7 convolution with a
    ``tanh`` activation.  The shape comments below use a 128x128x3 input as
    the worked example.

    Args:
        input_height: Height of the input image.  With the padding and
            strides used here the output height equals the input height
            only when this is divisible by 4.
        input_width: Width of the input image (same divisibility note).
        channel: Number of image channels, used for both the input layer
            and the output convolution so that generator outputs can be fed
            back into a generator.

    Returns:
        A Keras ``Model`` mapping the input image tensor to the generated
        image tensor.
    """
    img_input = Input(shape=(input_height, input_width, channel))

    # 128,128,3 -> 128,128,64
    x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input)
    x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING)(x)
    x = InstanceNormalization(axis=3)(x)
    x = Activation('relu')(x)

    # 128,128,64 -> 64,64,128
    x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x)
    x = Conv2D(128, (3, 3), data_format=IMAGE_ORDERING, strides=2)(x)
    x = InstanceNormalization(axis=3)(x)
    x = Activation('relu')(x)

    # 64,64,128 -> 32,32,256
    x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x)
    x = Conv2D(256, (3, 3), data_format=IMAGE_ORDERING, strides=2)(x)
    x = InstanceNormalization(axis=3)(x)
    x = Activation('relu')(x)

    # Nine residual blocks at the lowest resolution; the index doubles as
    # the block name so that layer names stay unique.
    for i in range(9):
        x = identity_block(x, 3, 256, block=str(i))

    # 32,32,256 -> 64,64,128
    x = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(x)
    x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x)
    x = Conv2D(128, (3, 3), data_format=IMAGE_ORDERING)(x)
    x = InstanceNormalization(axis=3)(x)
    x = Activation('relu')(x)

    # 64,64,128 -> 128,128,64
    x = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(x)
    x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x)
    x = Conv2D(64, (3, 3), data_format=IMAGE_ORDERING)(x)
    x = InstanceNormalization(axis=3)(x)
    x = Activation('relu')(x)

    # 128,128,64 -> 128,128,channel; tanh bounds the output to [-1, 1].
    x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(x)
    x = Conv2D(channel, (7, 7), data_format=IMAGE_ORDERING)(x)
    x = Activation('tanh')(x)

    model = Model(img_input, x)
    return model
判别网络Discriminator
判别网络主要是判别生成的图片的真假。
具体代码,如下:
def build_discriminator(self):
    """Build the patch discriminator.

    Four stride-2 convolution stages (64, 128, 256, 512 filters) halve the
    spatial size each time; the first stage skips instance normalization.
    A final stride-1 convolution produces a one-channel map, i.e. one
    real/fake score per spatial position rather than a single scalar.

    Returns:
        A Keras ``Model`` mapping an image of shape ``self.img_shape`` to
        the validity map.
    """
    def down_block(tensor, n_filters, use_norm):
        """Stride-2 4x4 convolution, optional instance norm, LeakyReLU."""
        out = Conv2D(n_filters, kernel_size=4, strides=2, padding='same')(tensor)
        if use_norm:
            out = InstanceNormalization()(out)
        return LeakyReLU(alpha=0.2)(out)

    img = Input(shape=self.img_shape)

    # For a 128x128 input the feature maps are, in order:
    # 64,64,64 -> 32,32,128 -> 16,16,256 -> 8,8,512
    stages = ((64, False), (128, True), (256, True), (512, True))
    features = img
    for n_filters, use_norm in stages:
        features = down_block(features, n_filters, use_norm)

    # Score every position of the final feature map: 8,8,512 -> 8,8,1,
    # which is 64 scores for a 128x128 input.
    validity = Conv2D(1, kernel_size=3, strides=1, padding='same')(features)

    return Model(img, validity)
在判别网络中最后输出为8×8×1,可以理解为有64个评委对图片的真伪做出了打分。
扫描二维码关注公众号,回复:
11639194 查看本文章
loss计算
判别模型Discriminator的loss主要采用了均方差(MSE),这种计算方式据说可以提高预测的准确性。训练的标签并不是配对的,而是我们利用网络引导生成的。
训练
# 创建生成模型
# Build the two generators: g_AB is applied to domain-A images, g_BA to
# domain-B images.
self.g_AB = self.build_generator()
self.g_BA = self.build_generator()
img_A = Input(shape=self.img_shape)
img_B = Input(shape=self.img_shape)
# Translate A -> fake B.
fake_B = self.g_AB(img_A)
# Translate B -> fake A.
fake_A = self.g_BA(img_B)
# Cycle A -> B -> A: map fake_B back with g_BA to reconstruct A.
reconstr_A = self.g_BA(fake_B)
# Cycle B -> A -> B: map fake_A back with g_AB to reconstruct B.
reconstr_B = self.g_AB(fake_A)
self.g_AB.summary()
# Identity mapping: feed img_A through g_BA (compared against img_A later).
img_A_id = self.g_BA(img_A)
# Identity mapping: feed img_B through g_AB (compared against img_B later).
img_B_id = self.g_AB(img_B)
# The discriminators are marked non-trainable before the combined model is
# built, so this part is meant to train only the generators.
self.d_A.trainable = False
self.d_B.trainable = False
# Discriminator scores for the generated images.
valid_A = self.d_A(fake_A)
valid_B = self.d_B(fake_B)
# Combined model used for the generator update. Outputs, in order:
# two validity maps, two cycle reconstructions, two identity mappings.
self.combined = Model(inputs=[img_A, img_B],
outputs=[ valid_A, valid_B,
reconstr_A, reconstr_B,
img_A_id, img_B_id ])
# Losses follow the output order: MSE on the validity maps, MAE on the
# reconstructions (weighted by lambda_cycle) and on the identity mappings
# (weighted by lambda_id).
self.combined.compile(loss=['mse', 'mse',
'mae', 'mae',
'mae', 'mae'],
loss_weights=[0.5, 0.5,
self.lambda_cycle, self.lambda_cycle,
self.lambda_id, self.lambda_id ],
optimizer=optimizer)
def train(self, init_epoch, epochs, batch_size=1, sample_interval=50):
    """Run the adversarial training loop.

    For every batch the generators are updated through ``self.combined``,
    then each discriminator is updated on one real and one generated batch.
    A progress line is printed per batch, sample images are written every
    ``sample_interval`` batches, and the weights of all four networks are
    saved every fifth epoch.

    Args:
        init_epoch: Epoch to start from.  When non-zero, the weights saved
            for that epoch are loaded from ``weights/<dataset_name>/``
            before training resumes.
        epochs: Epoch index at which training stops (exclusive).
        batch_size: Number of images per batch; also sizes the label arrays.
        sample_interval: Call ``self.sample_images`` every this many batches.
    """
    start_time = datetime.datetime.now()

    # Target maps for the discriminators: all ones for real images, all
    # zeros for generated ones, shaped like the discriminator output.
    valid = np.ones((batch_size,) + self.disc_patch)
    fake = np.zeros((batch_size,) + self.disc_patch)

    model_names = ("d_A", "d_B", "g_AB", "g_BA")

    if init_epoch != 0:
        for name in model_names:
            getattr(self, name).load_weights(
                "weights/%s/%s_epoch%d.h5" % (self.dataset_name, name, init_epoch),
                skip_mismatch=True)

    for epoch in range(init_epoch, epochs):
        self.scheduler([self.combined, self.d_A, self.d_B], epoch)

        for batch_i, (imgs_A, imgs_B) in enumerate(self.data_loader.load_batch(batch_size)):
            # ------------------ #
            #  Train generators
            # ------------------ #
            # Targets follow the combined model's output order: two
            # validity maps, two reconstructions, two identity mappings.
            g_loss = self.combined.train_on_batch([imgs_A, imgs_B],
                                                  [valid, valid,
                                                   imgs_A, imgs_B,
                                                   imgs_A, imgs_B])

            # ---------------------- #
            #  Train discriminators
            # ---------------------- #
            # Generated images for this batch: A -> fake B, B -> fake A.
            fake_B = self.g_AB.predict(imgs_A)
            fake_A = self.g_BA.predict(imgs_B)

            # Discriminator A: real A images vs. generated A images.
            dA_loss_real = self.d_A.train_on_batch(imgs_A, valid)
            dA_loss_fake = self.d_A.train_on_batch(fake_A, fake)
            dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake)

            # Discriminator B: real B images vs. generated B images.
            dB_loss_real = self.d_B.train_on_batch(imgs_B, valid)
            dB_loss_fake = self.d_B.train_on_batch(fake_B, fake)
            dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)

            d_loss = 0.5 * np.add(dA_loss, dB_loss)

            elapsed_time = datetime.datetime.now() - start_time

            # g_loss is [total, adv_A, adv_B, recon_A, recon_B, id_A, id_B],
            # so each pair is averaged with a two-element slice; the
            # identity pair lives at indices 5 and 6.
            print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %05f, adv: %05f, recon: %05f, id: %05f] time: %s " \
                   % (epoch, epochs,
                      batch_i, self.data_loader.n_batches,
                      d_loss[0], 100*d_loss[1],
                      g_loss[0],
                      np.mean(g_loss[1:3]),
                      np.mean(g_loss[3:5]),
                      np.mean(g_loss[5:7]),
                      elapsed_time))

            if batch_i % sample_interval == 0:
                self.sample_images(epoch, batch_i)

        # NOTE(review): the original indentation of this checkpoint block was
        # lost; it is placed at epoch level here -- confirm against the repo.
        # The start epoch is skipped so a just-loaded checkpoint is not
        # overwritten.
        if epoch % 5 == 0 and epoch != init_epoch:
            os.makedirs('weights/%s' % self.dataset_name, exist_ok=True)
            for name in model_names:
                getattr(self, name).save_weights(
                    "weights/%s/%s_epoch%d.h5" % (self.dataset_name, name, epoch))