1. Project Background
Since generative adversarial networks (GANs) were introduced, they have enabled many novel applications. This article covers one small application: colorizing black-and-white images. You have probably seen videos online in which early black-and-white footage is restored in color; the technology behind them is GAN-based. The animation below shows a case from our actual project.
2. Introduction to the principle
We will colorize black-and-white images with Pix2Pix, a general framework that can be applied to paired image-to-image translation tasks. Let's first introduce its principle.
The inputs of its generator and discriminator are real images rather than noise. The input image x passes through the generator G to produce the generated image G(x). The discriminator D then receives two kinds of pairs: (x, G(x)), which it should classify as fake, and (x, y), where y is the ground-truth target image, which it should classify as real.
G is a common encoder-decoder structure, and D is an ordinary classifier. What are the advantages of such a generative framework?
The authors argue that a plain encoder-decoder trained with a reconstruction loss handles the low-frequency components well but produces unsatisfactory high-frequency detail, whereas a GAN is good at generating high-frequency components. The total generator objective is therefore a standard conditional GAN loss plus an L1 reconstruction loss weighted by a coefficient λ.
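Written out in the notation above (x is the input image, y the ground-truth target, G the generator, D the discriminator), the objective can be stated as follows. This is a reconstruction of the standard Pix2Pix formulation; the noise input z used in the paper is omitted, and λ denotes the L1 weight.

```latex
\begin{aligned}
\mathcal{L}_{cGAN}(G, D) &= \mathbb{E}_{x,y}\big[\log D(x, y)\big]
  + \mathbb{E}_{x}\big[\log\big(1 - D(x, G(x))\big)\big] \\
\mathcal{L}_{L1}(G) &= \mathbb{E}_{x,y}\big[\lVert y - G(x) \rVert_{1}\big] \\
G^{*} &= \arg\min_{G}\max_{D}\; \mathcal{L}_{cGAN}(G, D) + \lambda\,\mathcal{L}_{L1}(G)
\end{aligned}
```

The adversarial term pushes G toward sharp, realistic detail, while the L1 term keeps G(x) close to the target y overall.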
For the specific implementation, we look at the following code.
3. Data preprocessing
For image colorization, the CIELab color space gives better results than RGB, because its L channel carries only lightness (grayscale) information while the a and b channels carry only color information. Brightness and color are therefore separated.
The figure below shows the color distribution in CIELab, which is more linear and uniform than in other color spaces.
Therefore, the data reading module has to convert each RGB image to the CIELab color space and then construct the paired data. Let's take a look at the core functions of the data reading class: the initialization function __init__ and the data accessor __getitem__.
The data class is defined as follows:
import os

import numpy as np
import torchvision.transforms as transforms
from PIL import Image
from skimage import color  # provides rgb2lab

# BaseDataset, get_transform and make_dataset are helpers from the framework's data package.


class ColorizationDataset(BaseDataset):
    def __init__(self, opt):
        BaseDataset.__init__(self, opt)
        self.dir = os.path.join(opt.dataroot, opt.phase)
        self.AB_paths = sorted(make_dataset(self.dir, opt.max_dataset_size))
        assert(opt.input_nc == 1 and opt.output_nc == 2 and opt.direction == 'AtoB')
        self.transform = get_transform(self.opt, convert=False)

    def __getitem__(self, index):
        path = self.AB_paths[index]
        im = Image.open(path).convert('RGB')  # read the image as RGB
        im = self.transform(im)  # apply preprocessing
        im = np.array(im)
        lab = color.rgb2lab(im).astype(np.float32)  # convert the RGB image to CIELab
        lab_t = transforms.ToTensor()(lab)
        L = lab_t[[0], ...] / 50.0 - 1.0  # L channel (index 0): scale [0, 100] to [-1, 1]
        AB = lab_t[[1, 2], ...] / 110.0  # a, b channels (index 1, 2): scale roughly [-110, 110] to [-1, 1]
        return {'A': L, 'B': AB, 'A_paths': path, 'B_paths': path}
In the __getitem__ function above, the image is first read with PIL and then, after preprocessing, converted to CIELab space. The L channel lies between 0 and 100 after conversion and is normalized to between -1 and 1. The a and b channels lie roughly between -110 and 110 after conversion (they are signed, not limited to positive values) and are normalized to roughly between -1 and 1.
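The scaling described above can be sketched in plain Python. The helper names `normalize_lab` and `denormalize_lab` are hypothetical and not part of the project code; the inverse function illustrates what would be needed to map a network prediction back to CIELab values before converting to RGB.

```python
# Hypothetical helpers mirroring the scaling in __getitem__; not part of the project code.
L_CENTER = 50.0   # L lies in [0, 100]; dividing by 50 and subtracting 1 gives [-1, 1]
AB_SCALE = 110.0  # a and b lie roughly in [-110, 110]; dividing by 110 gives about [-1, 1]


def normalize_lab(l, a, b):
    """Scale one CIELab pixel to the ranges the network is trained on."""
    return l / L_CENTER - 1.0, a / AB_SCALE, b / AB_SCALE


def denormalize_lab(l_n, a_n, b_n):
    """Invert normalize_lab, e.g. before converting a prediction back to RGB."""
    return (l_n + 1.0) * L_CENTER, a_n * AB_SCALE, b_n * AB_SCALE


if __name__ == "__main__":
    for pixel in [(0.0, 0.0, 0.0), (100.0, 0.0, 0.0), (53.2, 80.1, 67.2)]:
        scaled = normalize_lab(*pixel)
        restored = denormalize_lab(*scaled)
        print(pixel, "->", scaled, "->", restored)
```

Keeping both the input and the target in about the same [-1, 1] range matches the Tanh output of the generator.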
In addition, the get_transform function is called in __init__ to build the preprocessing pipeline, which mainly includes image scaling, random cropping and random flipping. Because convert=False is passed, the pipeline returns a PIL image, so the tensor conversion and normalization are done in __getitem__ instead. Since these are general operations, we do not walk through their code here.
4. Generator Network
The generator uses the U-Net structure. A residual (ResNet) structure can also be selected in this open-source framework, but we use U-Net for our experiments.
The U-Net generator is defined as follows:
import functools

import torch
import torch.nn as nn


class UnetGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetGenerator, self).__init__()
        # Build the U-Net from the innermost block outwards
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)  # add the innermost layer
        for i in range(num_downs - 5):  # intermediate blocks with ngf * 8 filters
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
        # Gradually reduce the number of channels from ngf * 8 to ngf
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)  # outermost layer

    def forward(self, input):
        """Standard forward"""
        return self.model(input)
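The important parameters are as follows: input_nc is the number of input channels, output_nc is the number of output channels, num_downs is the number of downsampling steps (so the bottleneck is 2^num_downs times smaller than the input along each side), ngf is the base number of filters, that is, the channel count produced by the outermost (first) convolution, and norm_layer is the normalization layer.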
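To make the role of num_downs and ngf concrete, the following sketch traces the feature-map shape after each downsampling step. It assumes a 256×256 input with num_downs=8 (the unet_256 configuration) and the stride-2, kernel-4, padding-1 convolutions used by the skip-connection blocks; the function name `unet_encoder_shapes` is hypothetical, and the channel formula matches the constructor only for num_downs of at least 5.

```python
def unet_encoder_shapes(size=256, num_downs=8, ngf=64):
    """Return (channels, side length) after each downsampling convolution."""
    shapes = []
    for level in range(num_downs):
        # Channel plan from the constructor: ngf, 2*ngf, 4*ngf, then 8*ngf for deeper levels
        channels = ngf * min(2 ** level, 8)
        # Output size of Conv2d(kernel_size=4, stride=2, padding=1)
        size = (size + 2 * 1 - 4) // 2 + 1
        shapes.append((channels, size))
    return shapes


if __name__ == "__main__":
    for level, (channels, side) in enumerate(unet_encoder_shapes(), start=1):
        print("down %d: %4d channels, %3d x %-3d" % (level, channels, side, side))
```

With these settings the side length halves eight times, from 256 down to 1, so the bottleneck is a 1×1 feature map with ngf * 8 channels.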
UnetSkipConnectionBlock is the skip-connection module. It is defined as follows:
class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        # The convolution bias is only needed with InstanceNorm
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]
            # Optionally add dropout
            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:  # the outermost block returns its output directly
            return self.model(x)
        else:  # otherwise concatenate input and output (skip connection)
            return torch.cat([x, self.model(x)], 1)
Here outer_nc is the number of channels in the outer layer, inner_nc is the number of channels in the inner layer, input_nc is the number of input channels, submodule is the previously built inner submodule that this block wraps, outermost indicates whether this is the outermost block, innermost indicates whether it is the innermost block, norm_layer is the normalization layer, and use_dropout controls whether dropout is used.
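One detail that is easy to miss is why the transposed convolutions take inner_nc * 2 input channels: every non-outermost block returns its input concatenated with its output, so the parent block receives twice as many channels. A minimal sketch of that bookkeeping, with hypothetical function names and the default ngf=64 and output_nc=2:

```python
def block_output_channels(input_nc, outer_nc, outermost=False):
    """Channels returned by forward() of one skip-connection block."""
    if outermost:
        return outer_nc
    # torch.cat([x, self.model(x)], 1): x has input_nc channels, the output has outer_nc
    return input_nc + outer_nc


def decoder_upconv_channels(ngf=64, output_nc=2):
    """(in_channels, out_channels) of the ConvTranspose2d in the four outer blocks."""
    # (outer_nc, inner_nc) pairs, listed from the outermost block inwards
    blocks = [
        (output_nc, ngf),
        (ngf, ngf * 2),
        (ngf * 2, ngf * 4),
        (ngf * 4, ngf * 8),
    ]
    return [(inner_nc * 2, outer_nc) for outer_nc, inner_nc in blocks]


if __name__ == "__main__":
    for in_ch, out_ch in decoder_upconv_channels():
        print("ConvTranspose2d(%d -> %d)" % (in_ch, out_ch))
    # A child block with 512 input and 512 output channels hands 1024 channels to its parent
    print(block_output_channels(512, 512))
```

The innermost block is the exception: it has no submodule, so its transposed convolution takes inner_nc rather than inner_nc * 2 channels.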
For the pix2pix model, the normalization layer defaults to nn.BatchNorm2d. With a batch size of 1 in training mode, it computes the same per-channel statistics as InstanceNorm, so the two are effectively equivalent.
5. Discriminator Definition
Next, let's look at the definition of the discriminator. The discriminator is a classification model, but it does not judge the whole image at once; it judges image patches. Its output is therefore not a single number but a map of predictions, one per patch, and the loss is averaged over all of them to obtain the final result. This discriminator is called PatchGAN and is defined as follows:
class NLayerDiscriminator(nn.Module):
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        super(NLayerDiscriminator, self).__init__()
        # Check the normalization layer type; no convolution bias is needed with BatchNorm
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        kw = 4  # convolution kernel size
        padw = 1  # padding size
        # First convolution layer
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        # Intermediate convolution layers (two when n_layers=3)
        for n in range(1, n_layers):  # gradually widen the channels, doubling each time up to ndf * 8
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        # Fourth convolution layer (stride 1)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        # Fifth convolution layer: output a single-channel prediction map
        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        return self.model(input)
Here input_nc is the number of input image channels, ndf is the number of filters in the first convolution layer, n_layers is the number of stride-2 convolution layers, and norm_layer is the normalization layer type. As the code shows, five convolution layers are used by default, each with a 4×4 kernel. The first three layers use stride 2 and the last two use stride 1, giving a total receptive field of 70×70, which is why the discriminator is described as judging 70×70 patches. Working layer by layer, the receptive field grows to 4, 10, 22, 46 and finally 70 pixels; for a 256×256 input, the feature maps measure 128, 64, 32, 31 and 30 pixels per side.
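The 70×70 figure can be checked with a short calculation. The sketch below uses the standard receptive-field recurrence (each layer adds (kernel - 1) times the accumulated stride) and the usual convolution output-size formula; the function names are hypothetical, and the layer list simply mirrors the kernel sizes, strides and padding in NLayerDiscriminator.

```python
def patchgan_layers(n_layers=3):
    """(kernel_size, stride, padding) of every convolution in the discriminator."""
    layers = [(4, 2, 1)]                    # first convolution
    layers += [(4, 2, 1)] * (n_layers - 1)  # intermediate stride-2 convolutions
    layers += [(4, 1, 1), (4, 1, 1)]        # stride-1 convolution and output convolution
    return layers


def receptive_fields(layers):
    """Receptive field (in input pixels) of one output unit after each layer."""
    rf, jump = 1, 1
    fields = []
    for kernel, stride, _ in layers:
        rf += (kernel - 1) * jump  # each layer widens the field by (k - 1) input-space steps
        jump *= stride             # distance in input pixels between neighbouring units
        fields.append(rf)
    return fields


def output_sizes(size, layers):
    """Side length of the feature map after each layer."""
    sizes = []
    for kernel, stride, padding in layers:
        size = (size + 2 * padding - kernel) // stride + 1
        sizes.append(size)
    return sizes


if __name__ == "__main__":
    layers = patchgan_layers()
    print(receptive_fields(layers))   # receptive field per layer
    print(output_sizes(256, layers))  # feature-map side per layer for a 256x256 input
```

For a 256×256 input the final prediction map is 30×30, so the loss is averaged over 900 overlapping patch decisions.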
6. Loss function definition
Next we look at the definition of the loss function.
class GANLoss(nn.Module):
    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        # gan_mode: the loss type; supports the vanilla GAN loss, lsgan and wgangp
        super(GANLoss, self).__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode in ['wgangp']:
            self.loss = None
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

    # Expand the label to the same size as the prediction map
    def get_target_tensor(self, prediction, target_is_real):
        if target_is_real:
            target_tensor = self.real_label
        else:
            target_tensor = self.fake_label
        return target_tensor.expand_as(prediction)

    # Return the loss
    def __call__(self, prediction, target_is_real):
        if self.gan_mode in ['lsgan', 'vanilla']:
            target_tensor = self.get_target_tensor(prediction, target_is_real)
            loss = self.loss(prediction, target_tensor)
        elif self.gan_mode == 'wgangp':
            if target_is_real:
                loss = -prediction.mean()
            else:
                loss = prediction.mean()
        return loss
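The above code implements several common GAN adversarial losses: the vanilla GAN loss (binary cross-entropy on logits), the least-squares LSGAN loss (mean squared error), and the WGAN-GP critic loss (the gradient penalty term itself is not computed in this class).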
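To see how the three modes differ numerically, here is a dependency-free sketch that applies the same rules to a flat list of patch scores. It is an illustration under the assumption that the scores are raw discriminator outputs (logits); the function names are hypothetical and the real class operates on tensors.

```python
import math


def bce_with_logits(logits, target):
    """Mean binary cross-entropy on raw scores, in the numerically stable form."""
    total = 0.0
    for z in logits:
        # max(z, 0) - z * target + log(1 + exp(-|z|))
        total += max(z, 0.0) - z * target + math.log1p(math.exp(-abs(z)))
    return total / len(logits)


def mse(scores, target):
    """Mean squared error between the scores and a constant target."""
    return sum((s - target) ** 2 for s in scores) / len(scores)


def gan_loss(patch_scores, target_is_real, gan_mode="vanilla"):
    """Mirror of GANLoss.__call__ for a flat list of patch scores."""
    target = 1.0 if target_is_real else 0.0
    if gan_mode == "vanilla":
        return bce_with_logits(patch_scores, target)
    if gan_mode == "lsgan":
        return mse(patch_scores, target)
    if gan_mode == "wgangp":
        mean = sum(patch_scores) / len(patch_scores)
        return -mean if target_is_real else mean
    raise NotImplementedError("gan mode %s not implemented" % gan_mode)


if __name__ == "__main__":
    scores = [2.0, -1.0, 0.5]
    for mode in ("vanilla", "lsgan", "wgangp"):
        print(mode, gan_loss(scores, True, mode), gan_loss(scores, False, mode))
```

In every mode the label is broadcast to all patches, which is what get_target_tensor does with expand_as.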
7. Complete structure definition
After defining the discriminator and generator, let's look at the definition of the complete pix2pix model, as follows:
class Pix2PixModel(BaseModel):
    # Configure default options
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        # Defaults: batch norm, the unet_256 generator and a paired (aligned) image dataset
        parser.set_defaults(norm='batch', netG='unet_256', dataset_mode='aligned')
        if is_train:
            parser.set_defaults(pool_size=0, gan_mode='vanilla')  # use the vanilla GAN loss
            parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss')  # L1 loss weight of 100
        return parser

    def __init__(self, opt):
        BaseModel.__init__(self, opt)
        self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']  # losses to log
        self.visual_names = ['real_A', 'fake_B', 'real_B']  # intermediate result images
        if self.isTrain:
            self.model_names = ['G', 'D']
        else:  # during test time, only load G
            self.model_names = ['G']
        # Define the generator
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
                                      not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
        # Define the discriminator; its input is the input image concatenated with the output image
        if self.isTrain:
            self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
                                          opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
        if self.isTrain:
            # Loss functions: the standard GAN loss and the L1 reconstruction loss
            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
            self.criterionL1 = torch.nn.L1Loss()
            # Optimizers: Adam for both networks
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

    def set_input(self, input):
        # Unpack the input; A and B are assigned according to the direction
        AtoB = self.opt.direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    # Generator forward pass
    def forward(self):
        self.fake_B = self.netG(self.real_A)  # G(A)

    # Discriminator loss
    def backward_D(self):
        # Loss on fake samples; detach so that no gradients flow back into G
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(pred_fake, False)
        # Loss on real samples
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        pred_real = self.netD(real_AB)
        self.loss_D_real = self.criterionGAN(pred_real, True)
        # Average the real and fake losses
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    # Generator loss
    def backward_G(self):
        # GAN loss: G tries to make D classify the fake pair as real
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD(fake_AB)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)
        # Reconstruction loss
        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1
        # Weighted sum of the two losses
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

    def optimize_parameters(self):
        self.forward()  # compute G(A)
        # Update D
        self.set_requires_grad(self.netD, True)  # enable gradients for D
        self.optimizer_D.zero_grad()  # clear D's gradients
        self.backward_D()  # compute D's gradients
        self.optimizer_D.step()  # update D's weights
        # Update G
        self.set_requires_grad(self.netD, False)  # D needs no gradients while optimizing G
        self.optimizer_G.zero_grad()  # clear G's gradients
        self.backward_G()  # compute G's gradients
        self.optimizer_G.step()  # update G's weights
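The way the individual terms are combined in backward_D and backward_G can be illustrated with plain numbers. The functions below are a hypothetical, list-based sketch; the real code works on tensors, and the input values in the demo are made up purely to show the relative scale of the terms.

```python
def l1_loss(fake, real):
    """Mean absolute difference, the same reduction nn.L1Loss uses by default."""
    return sum(abs(f - r) for f, r in zip(fake, real)) / len(fake)


def generator_loss(loss_gan, fake, real, lambda_l1=100.0):
    """loss_G = loss_G_GAN + lambda_L1 * L1(fake_B, real_B)."""
    return loss_gan + lambda_l1 * l1_loss(fake, real)


def discriminator_loss(loss_fake, loss_real):
    """loss_D = (loss_D_fake + loss_D_real) * 0.5."""
    return (loss_fake + loss_real) * 0.5


if __name__ == "__main__":
    fake_b = [0.0, 0.5]   # made-up generator outputs
    real_b = [0.1, 0.3]   # made-up targets
    print("L1 term:", 100.0 * l1_loss(fake_b, real_b))
    print("loss_G :", generator_loss(0.7, fake_b, real_b))
    print("loss_D :", discriminator_loss(0.4, 0.6))
```

With lambda_L1 set to 100, even a small average pixel error contributes far more to loss_G than the adversarial term, which is why the reconstruction loss steers the overall colors while the GAN term refines the detail.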
The above completes the interpretation of the core code in the project. Next, we train and test the model.
8. Dataset Preparation
First we prepare the color dataset A
Then we prepare the black and white dataset B
We choose the direction of training as B to A
8. Model training
Model training covers model creation, data loading, visualization and checkpoint saving. The core code is as follows:
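import time

# TrainOptions, create_dataset, create_model and Visualizer are provided by the framework.

if __name__ == '__main__':
    opt = TrainOptions().parse()  # parse the training options
    dataset = create_dataset(opt)  # create the dataset
    dataset_size = len(dataset)  # number of images in the dataset
    print('The number of training images = %d' % dataset_size)
    model = create_model(opt)  # create the model
    model.setup(opt)  # initialize the model
    visualizer = Visualizer(opt)  # visualization helper
    total_iters = 0  # total number of training samples processed
    for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
        epoch_iter = 0  # number of samples processed in the current epoch
        iter_data_time = time.time()  # start of data loading for the next batch
        for i, data in enumerate(dataset):  # inner loop over one epoch
            iter_start_time = time.time()  # start of computation for this batch
            if total_iters % opt.print_freq == 0:
                t_data = iter_start_time - iter_data_time  # data loading time
            visualizer.reset()
            total_iters += opt.batch_size
            epoch_iter += opt.batch_size
            model.set_input(data)  # feed the input data
            model.optimize_parameters()  # run one optimization step
            if total_iters % opt.display_freq == 0:  # visualize with visdom
                save_result = total_iters % opt.update_html_freq == 0
                model.compute_visuals()
                visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)
            if total_iters % opt.print_freq == 0:  # print and log the losses
                losses = model.get_current_losses()
                t_comp = (time.time() - iter_start_time) / opt.batch_size  # computation time per sample
                visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data)
                if opt.display_id > 0:
                    visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses)
            if total_iters % opt.save_latest_freq == 0:  # save the latest model
                print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
                save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest'
                model.save_networks(save_suffix)
            iter_data_time = time.time()
        if epoch % opt.save_epoch_freq == 0:  # save the model every opt.save_epoch_freq epochs
            model.save_networks('latest')
            model.save_networks(epoch)
        model.update_learning_rate()  # update the learning rate after every epoch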
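Some of the important training parameters are configured as follows:
input_nc=1: the generator input is a 1-channel image, i.e. the L channel.
output_nc=2: the generator output is a 2-channel image, i.e. the a and b channels.
ngf=64: the base number of generator filters is 64.
ndf=64: the base number of discriminator filters is 64.
n_layers_D=3: the default PatchGAN is used, which judges image patches of 70×70 pixels.
norm=batch, batch_size=1: batch normalization is used with a batch size of 1.
load_size=286: the size to which images are scaled when loaded.
crop_size=256: the size to which images are cropped, i.e. the training size.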
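The interplay of load_size and crop_size can be sketched as follows: the image is first scaled to 286×286, then a random 256×256 window is cut out and optionally flipped, which gives a light form of data augmentation. The function names and the returned dictionary layout are hypothetical; only the two sizes come from the configuration above.

```python
import random


def sample_crop_params(load_size=286, crop_size=256, rng=random):
    """Pick the top-left corner of a random crop and whether to flip."""
    max_offset = max(0, load_size - crop_size)
    x = rng.randint(0, max_offset)  # randint is inclusive on both ends
    y = rng.randint(0, max_offset)
    flip = rng.random() > 0.5
    return {"crop_pos": (x, y), "flip": flip}


def crop_box(params, crop_size=256):
    """(left, upper, right, lower) box, the layout PIL's Image.crop expects."""
    x, y = params["crop_pos"]
    return (x, y, x + crop_size, y + crop_size)


if __name__ == "__main__":
    rng = random.Random(0)  # seeded only to make the demo repeatable
    for _ in range(3):
        params = sample_crop_params(rng=rng)
        print(params, crop_box(params))
```

With these sizes each axis has 31 possible offsets, so one scaled image yields 961 distinct crop positions before flipping.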
9. Effect display
10. Project video display
11. Project Integration
12. Complete source code & environment deployment video tutorial & dataset:
Baidu Bread can download the source code by searching for the title name