Python based on CycleGAN & pix2pix AI coloring of black and white images (source code & deployment tutorial)

1. Project Background

Since GANs were introduced, they have enabled many novel applications. This article introduces one small application: colorizing black-and-white images. You have probably seen videos online in which early black-and-white footage is colorized and restored; the technology behind them is the GAN. The animation below shows a colorization result from our actual project.
2.png

2. Introduction to the principle

This time we implement black-and-white image colorization based on Pix2Pix. Pix2Pix is a general framework that can be applied to any image translation task. Let's first introduce its principle.
6.png
The inputs to its generator and discriminator are real images rather than noise. The input image x passes through the generator G to produce the generated image G(x). The pair (x, G(x)) is fed to the discriminator as one input, and the pair (x, y), where y is the real target image, is fed as the other. The discriminator should output fake for the former and real for the latter.
G is a common encoder-decoder structure, and D is an ordinary classifier. What are the advantages of such a generative framework?
The authors argue that a plain encoder-decoder structure handles low-frequency components well but produces unsatisfactory high-frequency detail, while a GAN is good at generating high-frequency components. The total generator loss is therefore a standard conditional GAN loss plus an L1 reconstruction loss, defined as follows:
7.png
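
For reference, the objective described above can be written out as follows. This is a sketch in the notation of this section (x is the input image, y the real target image, G the generator, D the discriminator); the weight λ is assumed to correspond to the lambda_L1 option that appears in the model code later in this article.

```latex
\begin{aligned}
\mathcal{L}_{cGAN}(G, D) &= \mathbb{E}_{x,y}\big[\log D(x, y)\big] + \mathbb{E}_{x}\big[\log\big(1 - D(x, G(x))\big)\big] \\
\mathcal{L}_{L1}(G) &= \mathbb{E}_{x,y}\big[\lVert y - G(x) \rVert_1\big] \\
G^{*} &= \arg\min_{G}\max_{D}\; \mathcal{L}_{cGAN}(G, D) + \lambda\,\mathcal{L}_{L1}(G)
\end{aligned}
```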

For the specific implementation, we look at the following code.

3. Data preprocessing

For image colorization, the CIELab color space gives better results than the RGB color space, because the L channel in CIELab carries only lightness (grayscale) information while the a and b channels carry only color information, which separates brightness from color.

The figure below shows the color distribution in the CIELab color space, which is more linear and uniform than in other color spaces.
8.png

Therefore, the data reading module must convert each RGB image to the CIELab color space and then construct the paired data. Let's look at the core functions of the data reading class: the initialization function __init__ and the item accessor __getitem__.

The data class is defined as follows:

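import os

import numpy as np
from PIL import Image
from skimage import color  # scikit-image, used for the RGB -> Lab conversion
import torchvision.transforms as transforms

# Module paths below are assumed to follow the pytorch-CycleGAN-and-pix2pix repository layout.
from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset


class ColorizationDataset(BaseDataset):
    def __init__(self, opt):
        BaseDataset.__init__(self, opt)
        self.dir = os.path.join(opt.dataroot, opt.phase)
        self.AB_paths = sorted(make_dataset(self.dir, opt.max_dataset_size))
        # Colorization maps 1 input channel (L) to 2 output channels (a, b)
        assert(opt.input_nc == 1 and opt.output_nc == 2 and opt.direction == 'AtoB')
        # convert=False keeps the result as a PIL image (no tensor conversion or normalization)
        self.transform = get_transform(self.opt, convert=False)

    def __getitem__(self, index):
        path = self.AB_paths[index]
        im = Image.open(path).convert('RGB')  # read the RGB image
        im = self.transform(im)  # apply preprocessing
        im = np.array(im)
        lab = color.rgb2lab(im).astype(np.float32)  # convert the RGB image to CIELab
        lab_t = transforms.ToTensor()(lab)
        L = lab_t[[0], ...] / 50.0 - 1.0  # normalize the L channel (index 0) to [-1, 1]
        AB = lab_t[[1, 2], ...] / 110.0  # normalize the a, b channels (index 1, 2) to roughly [-1, 1]
        return {'A': L, 'B': AB, 'A_paths': path, 'B_paths': path}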

In the __getitem__ function above, the image is first read with the PIL package, preprocessed, and then converted to the CIELab space. The L channel ranges from 0 to 100 after conversion and is normalized to between -1 and 1. The a and b channels range from roughly -110 to 110 after conversion and are normalized to between -1 and 1 by dividing by 110.
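
The scaling in __getitem__ can be checked on a single pixel without any image library. The sketch below is only an illustration: the helper names normalize_lab and denormalize_lab are hypothetical and not part of the project. The inverse is included because the network output has to be scaled back before it can be converted to RGB.

```python
def normalize_lab(l, a, b):
    """Scale one Lab pixel the same way __getitem__ scales the tensors."""
    return l / 50.0 - 1.0, a / 110.0, b / 110.0


def denormalize_lab(l_n, a_n, b_n):
    """Invert normalize_lab (hypothetical helper for illustration)."""
    return (l_n + 1.0) * 50.0, a_n * 110.0, b_n * 110.0


# Extremes of the assumed value ranges map to the ends of [-1, 1].
darkest = normalize_lab(0.0, -110.0, -110.0)
brightest = normalize_lab(100.0, 110.0, 110.0)
# A mid-gray pixel with no color maps to the origin.
mid_gray = normalize_lab(50.0, 0.0, 0.0)
restored = denormalize_lab(*normalize_lab(50.0, 55.0, -22.0))
```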

In addition, the __init__ function builds the preprocessing pipeline by calling get_transform, which mainly covers image scaling, random cropping, and random flipping. Because convert=False is passed, the result stays a PIL image and the usual tensor conversion and mean/standard-deviation normalization are skipped; the Lab normalization in __getitem__ takes their place. Since these are general-purpose operations, we do not walk through their code here.

4. Generator Network

The generator uses the U-Net structure. A residual (ResNet) structure can also be selected in this open-source framework, but we use U-Net for our experiments.

The U-Net generator is defined as follows:

import torch.nn as nn


class UnetGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetGenerator, self).__init__()
        # The U-Net is built from the innermost block outward
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)  # add the innermost layer
        # Intermediate blocks with ngf * 8 filters
        for i in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
        # Gradually reduce the number of channels from ngf * 8 to ngf
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        # Outermost layer
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)

    def forward(self, input):
        """Standard forward"""
        return self.model(input)

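The important parameters are as follows: input_nc is the number of input channels, output_nc is the number of output channels, num_downs is the number of downsampling steps (so the input is reduced by a factor of 2^num_downs at the bottleneck), ngf is the base number of generator filters (the channel width of the outermost block; deeper blocks use multiples of it up to ngf * 8), and norm_layer is the normalization layer.

UnetSkipConnectionBlock is the skip-connection module. It is defined as follows: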
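
To make the role of num_downs and ngf concrete, the sketch below lists the channel count and spatial size after each downsampling step of the encoder. It is plain arithmetic, not project code: the function name unet_encoder_shapes is hypothetical, and the defaults assume a 256×256 input with num_downs=8 (the unet_256 configuration used later in this article).

```python
def unet_encoder_shapes(size=256, num_downs=8, ngf=64):
    """Return (channels, spatial_size) after each stride-2 downsampling step.

    Mirrors how UnetGenerator stacks its blocks: widths ngf, 2*ngf, 4*ngf,
    then 8*ngf for every deeper block (valid for num_downs >= 5).
    """
    shapes = []
    for depth in range(num_downs):
        size //= 2                              # each block halves height and width
        channels = ngf * min(2 ** depth, 8)     # width is capped at ngf * 8
        shapes.append((channels, size))
    return shapes


shapes = unet_encoder_shapes()
for channels, size in shapes:
    print(f"{channels:4d} channels at {size}x{size}")
```

With these defaults the bottleneck is a 1×1 feature map with 512 channels, which is why a 256×256 input pairs naturally with eight downsampling steps.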

import functools

import torch
import torch.nn as nn


class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        # The conv bias is only enabled when InstanceNorm is used
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]
            # Optionally add dropout
            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:  # the outermost block returns the output directly
            return self.model(x)
        else:  # add the skip connection
            return torch.cat([x, self.model(x)], 1)

Here outer_nc is the number of outer channels, inner_nc is the number of inner channels, input_nc is the number of input channels, submodule is the previously built (inner) submodule, outermost indicates whether this is the outermost layer, innermost indicates whether this is the innermost layer, norm_layer is the normalization layer, and use_dropout indicates whether to use dropout.

For the pix2pix model, the normalization layer defaults to nn.BatchNorm2d. With a batch size of 1 during training, its statistics come from a single sample, so it is effectively equivalent to InstanceNorm.
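
The inner_nc * 2 passed to the up-convolutions follows from the concatenation in forward(). The sketch below only counts channels; block_output_channels is a hypothetical helper for illustration, not part of the project.

```python
def block_output_channels(outer_nc, input_nc=None, outermost=False):
    """Number of channels leaving a UnetSkipConnectionBlock's forward()."""
    if input_nc is None:
        input_nc = outer_nc          # same default as in the class above
    if outermost:
        return outer_nc              # the outermost block does not concatenate
    return input_nc + outer_nc       # torch.cat([x, self.model(x)], 1)


ngf = 64
# The innermost block is built with outer_nc = ngf * 8 ...
child_out = block_output_channels(outer_nc=ngf * 8)
# ... and is wrapped by a parent whose inner_nc is also ngf * 8,
# so the parent's up-convolution must accept inner_nc * 2 channels.
parent_upconv_in = (ngf * 8) * 2
# The outermost block for colorization: 1 input channel (L), 2 output channels (a, b).
final_out = block_output_channels(outer_nc=2, input_nc=1, outermost=True)
```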

5. Discriminator Definition

Next, let's look at the definition of the discriminator. The discriminator is a classification model, but it judges image patches rather than the whole picture at once. Its output is therefore not a single number but a map of predictions, one per image patch, which are combined to obtain the final result. This discriminator is called PatchGAN and is defined as follows:

import functools

import torch.nn as nn


class NLayerDiscriminator(nn.Module):
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        super(NLayerDiscriminator, self).__init__()
        # Check the normalization layer type; with BatchNorm no conv bias is needed
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        kw = 4  # convolution kernel size
        padw = 1  # padding size
        # First convolutional layer
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        # Middle two convolutional layers (for the default n_layers=3)
        for n in range(1, n_layers):  # gradually widen the channels, doubling each time
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        # Fourth convolutional layer (stride 1)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]
        # Fifth convolutional layer: output a single-channel prediction map
        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        return self.model(input)

Here input_nc is the number of input image channels, ndf is the number of output channels of the first convolutional layer, n_layers controls the number of discriminator layers, and norm_layer is the normalization layer type. As the code shows, the default configuration contains 5 convolutional layers, all with 4×4 kernels. The first 3 layers use stride=2 and the last two use stride=1, which gives a total receptive field of 70×70. This is why the discriminator is said to judge 70×70 image patches. The input and output of each layer and the receptive field statistics are as follows:
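
The 70×70 figure can be reproduced with the standard receptive-field recurrence. The sketch below is plain arithmetic over the (kernel, stride) pairs read off the code above; the function names are hypothetical, and the 256×256 input size is an assumption taken from the crop size used for training.

```python
def patchgan_layers(n_layers=3, kernel=4):
    """(kernel, stride) for each convolution in NLayerDiscriminator."""
    # n_layers stride-2 convolutions, then two stride-1 convolutions
    return [(kernel, 2)] * n_layers + [(kernel, 1)] * 2


def receptive_field(layers):
    """Input pixels seen by one output unit of a stack of convolutions."""
    rf, jump = 1, 1
    for kernel, stride in layers:
        rf += (kernel - 1) * jump   # each layer widens the field by (k - 1) input steps
        jump *= stride              # spacing between neighboring units, in input pixels
    return rf


def output_size(size, layers, padding=1):
    """Spatial size of the prediction map for a square input."""
    for kernel, stride in layers:
        size = (size + 2 * padding - kernel) // stride + 1
    return size


layers = patchgan_layers()
rf = receptive_field(layers)
out = output_size(256, layers)
print(f"receptive field: {rf}x{rf}, prediction map: {out}x{out}")
```

Under these assumptions, a 256×256 input yields a 30×30 prediction map in which every entry scores one 70×70 patch of the input.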

9.png

6. Loss function definition

Next we look at the definition of the loss function.

import torch
import torch.nn as nn


class GANLoss(nn.Module):
    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        # gan_mode: loss type; supports the original (vanilla) loss, lsgan and wgangp
        super(GANLoss, self).__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode in ['wgangp']:
            self.loss = None
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

    # Expand the label to the same size as the prediction map
    def get_target_tensor(self, prediction, target_is_real):
        if target_is_real:
            target_tensor = self.real_label
        else:
            target_tensor = self.fake_label
        return target_tensor.expand_as(prediction)

    # Return the loss
    def __call__(self, prediction, target_is_real):
        if self.gan_mode in ['lsgan', 'vanilla']:
            target_tensor = self.get_target_tensor(prediction, target_is_real)
            loss = self.loss(prediction, target_tensor)
        elif self.gan_mode == 'wgangp':
            if target_is_real:
                loss = -prediction.mean()
            else:
                loss = prediction.mean()
        return loss

The above code implements the calculation of several common GAN adversarial losses.
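
To see what the 'vanilla' and 'lsgan' modes compute on a patch prediction map, here is a dependency-free sketch. It re-implements the two formulas in plain Python for illustration only (the project itself uses nn.BCEWithLogitsLoss and nn.MSELoss); the function names and the sample scores are made up.

```python
import math


def bce_with_logits(logits, target):
    """Mean binary cross-entropy on raw scores, as in 'vanilla' mode."""
    total = 0.0
    for z in logits:
        # Numerically stable form of -[t*log(sigmoid(z)) + (1-t)*log(1-sigmoid(z))]
        total += max(z, 0.0) - z * target + math.log1p(math.exp(-abs(z)))
    return total / len(logits)


def mse(preds, target):
    """Mean squared error against a constant label, as in 'lsgan' mode."""
    return sum((p - target) ** 2 for p in preds) / len(preds)


# A flattened discriminator output map (hypothetical raw scores, one per patch).
patch_scores = [2.0, -1.0, 0.0, 0.5]
# The scalar label is compared against every patch, like expand_as in GANLoss.
loss_if_real = bce_with_logits(patch_scores, 1.0)
loss_if_fake = bce_with_logits(patch_scores, 0.0)
lsgan_if_real = mse(patch_scores, 1.0)
```

Because the loss is averaged over all entries of the map, every patch contributes equally to the gradient.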

7. Complete structure definition

After defining the discriminator and generator, let's look at the definition of the complete pix2pix model, as follows:

import torch

# Module paths below are assumed to follow the pytorch-CycleGAN-and-pix2pix repository layout.
from .base_model import BaseModel
from . import networks


class Pix2PixModel(BaseModel):
    # Configure default options
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        # By default use batch norm, the unet_256 generator, and a paired (aligned) image dataset
        parser.set_defaults(norm='batch', netG='unet_256', dataset_mode='aligned')
        if is_train:
            parser.set_defaults(pool_size=0, gan_mode='vanilla')  # use the classic GAN loss
            parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss')  # L1 loss weight is 100
        return parser

    def __init__(self, opt):
        BaseModel.__init__(self, opt)
        self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']  # losses to report
        self.visual_names = ['real_A', 'fake_B', 'real_B']  # intermediate images to display
        if self.isTrain:
            self.model_names = ['G', 'D']
        else:  # during test time, only load G
            self.model_names = ['G']
        # Define the generator
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
                                      not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
        # Define the discriminator; its input is the input image concatenated with the generated (or real) image
        if self.isTrain:
            self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
                                          opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
        if self.isTrain:
            # Loss functions: standard GAN loss and L1 reconstruction loss
            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
            self.criterionL1 = torch.nn.L1Loss()
            # Optimizers: Adam
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

    def set_input(self, input):
        # Unpack the input; A and B are assigned according to the direction
        AtoB = self.opt.direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    # Generator forward pass
    def forward(self):
        self.fake_B = self.netG(self.real_A)  # G(A)

    # Discriminator loss
    def backward_D(self):
        # Loss on fake samples; detach() stops gradients from flowing into G
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(pred_fake, False)
        # Loss on real samples
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        pred_real = self.netD(real_AB)
        self.loss_D_real = self.criterionGAN(pred_real, True)
        # Average the real and fake losses
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    # Generator loss
    def backward_G(self):
        # GAN loss: G tries to make D classify the fake pair as real
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD(fake_AB)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)
        # Reconstruction loss
        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1
        # Weighted sum of the two losses
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

    def optimize_parameters(self):
        self.forward()                            # compute G(A)
        # Update D
        self.set_requires_grad(self.netD, True)   # enable gradients for D
        self.optimizer_D.zero_grad()              # reset D's gradients to zero
        self.backward_D()                         # compute gradients for D
        self.optimizer_D.step()                   # update D's weights
        # Update G
        self.set_requires_grad(self.netD, False)  # D needs no gradients while optimizing G
        self.optimizer_G.zero_grad()              # reset G's gradients to zero
        self.backward_G()                         # compute gradients for G
        self.optimizer_G.step()                   # update G's weights

The above completes the interpretation of the core code in the project. Next, we train and test the model.
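
As a quick sanity check on how the two objectives are combined in backward_D and backward_G, the sketch below plugs in made-up numbers. The function names and all loss values are hypothetical; only the formulas and the default weight of 100 come from the code above.

```python
def discriminator_loss(loss_fake, loss_real):
    """Average of the fake and real terms, as in backward_D."""
    return (loss_fake + loss_real) * 0.5


def generator_loss(loss_gan, mean_abs_error, lambda_l1=100.0):
    """Adversarial term plus weighted L1 term, as in backward_G."""
    return loss_gan + mean_abs_error * lambda_l1


# Hypothetical values: an undecided discriminator gives about log(2) ~ 0.69 per term.
loss_d = discriminator_loss(loss_fake=0.6, loss_real=0.8)
# A mean absolute error of only 0.05 already contributes 5.0 after weighting.
loss_g = generator_loss(loss_gan=0.7, mean_abs_error=0.05)
print(f"loss_D = {loss_d:.2f}, loss_G = {loss_g:.2f}")
```

With lambda_L1 = 100, the reconstruction term usually dominates the generator loss early in training, while the adversarial term mainly shapes fine detail.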

8. Dataset Preparation

First we prepare the color dataset A
10.png

Then we prepare the black and white dataset B
11.png

We train in the B-to-A direction, that is, from black-and-white to color.

8. Model training

The training script handles model definition, data loading, visualization, and checkpoint saving. The core code is as follows:

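import time

# Module paths below are assumed to follow the pytorch-CycleGAN-and-pix2pix repository layout.
from options.train_options import TrainOptions
from data import create_dataset
from models import create_model
from util.visualizer import Visualizer

if __name__ == '__main__':
    opt = TrainOptions().parse()   # get the training options
    dataset = create_dataset(opt)  # create the dataset
    dataset_size = len(dataset)    # number of images in the dataset
    print('The number of training images = %d' % dataset_size)
    model = create_model(opt)      # create the model
    model.setup(opt)               # initialize the model
    visualizer = Visualizer(opt)   # create the visualizer
    total_iters = 0                # total number of training iterations (counted in samples)

    for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
        iter_data_time = time.time()    # timer for data loading per iteration
        epoch_iter = 0                  # number of iterations in the current epoch
        for i, data in enumerate(dataset):  # inner loop within one epoch
            iter_start_time = time.time()   # timer for computation per iteration
            t_data = iter_start_time - iter_data_time
            visualizer.reset()
            total_iters += opt.batch_size
            epoch_iter += opt.batch_size
            model.set_input(data)         # feed the input data
            model.optimize_parameters()   # run one optimization step

            if total_iters % opt.display_freq == 0:   # visdom visualization
                save_result = total_iters % opt.update_html_freq == 0
                model.compute_visuals()
                visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)

            if total_iters % opt.print_freq == 0:    # print and log the losses
                losses = model.get_current_losses()
                t_comp = (time.time() - iter_start_time) / opt.batch_size
                visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data)
                if opt.display_id > 0:
                    visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses)

            if total_iters % opt.save_latest_freq == 0:   # save the latest model
                print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
                save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest'
                model.save_networks(save_suffix)

            iter_data_time = time.time()

        if epoch % opt.save_epoch_freq == 0:  # save the model every opt.save_epoch_freq epochs
            model.save_networks('latest')
            model.save_networks(epoch)

        model.update_learning_rate()  # update the learning rate after every epoch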
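
The epoch range niter + niter_decay in the loop above suggests a two-phase schedule: a constant learning rate followed by a decay phase. The article does not show which policy update_learning_rate applies, so the sketch below assumes a linear decay to zero; the function name lr_multiplier and the default values of 100 epochs per phase are likewise assumptions.

```python
def lr_multiplier(epoch, epoch_count=1, niter=100, niter_decay=100):
    """Factor applied to the initial learning rate after `epoch` scheduler steps.

    Assumed policy: constant for the first niter epochs, then linear decay
    toward zero over the following niter_decay epochs.
    """
    decay_progress = max(0, epoch + epoch_count - niter)
    return 1.0 - decay_progress / float(niter_decay + 1)


initial_lr = 0.0002  # hypothetical starting value
for step in (0, 99, 100, 150, 200):
    print(f"after {step:3d} steps: lr = {initial_lr * lr_multiplier(step):.6f}")
```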

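Some important training parameters are configured as follows:

input_nc=1: the generator input is a 1-channel image, i.e. the L channel.

output_nc=2: the generator output is a 2-channel image, i.e. the a and b channels.

ngf=64: the base number of generator filters is 64 (the channel width of the outermost U-Net block).

ndf=64: the first convolutional layer of the discriminator outputs 64 channels.

n_layers_D=3: the default PatchGAN is used, which is equivalent to judging image patches of size 70×70.

norm=batch, batch_size=1: batch normalization is used, with a batch size of 1.

load_size=286: the size to which loaded images are scaled.

crop_size=256: the size to which images are cropped, i.e. the training size.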
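
The gap between load_size and crop_size is what gives random cropping room to act as augmentation. The sketch below shows the arithmetic only; random_crop_box is a hypothetical helper, not the project's transform code.

```python
import random


def random_crop_box(load_size=286, crop_size=256, rng=random):
    """Pick a random crop_size x crop_size window inside a load_size x load_size image."""
    max_offset = load_size - crop_size          # 30 pixels of slack with the defaults
    left = rng.randint(0, max_offset)           # randint is inclusive on both ends
    top = rng.randint(0, max_offset)
    return left, top, left + crop_size, top + crop_size


# A seeded generator keeps the example reproducible.
box = random_crop_box(rng=random.Random(0))
left, top, right, bottom = box
print(f"crop window: ({left}, {top}) to ({right}, {bottom})")
```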

9. Effect display

3.png

4.png

5.png

10. Project video display

Python based on CycleGAN & pix2pix AI coloring of black and white images (source code & deployment tutorial)

11. Project Integration

1.png

12. Complete source code & environment deployment video tutorial & dataset:

The source code can be downloaded by searching for this article's title on Baidu Bread.

Origin blog.csdn.net/cheng2333333/article/details/126747993