Huawei's open-source, self-developed AI framework Shengsi MindSpore application case: Colorization (automatic image coloring)

Colorization
When Dorothy steps into Oz in 1939's The Wizard of Oz, the transition from black and white to vibrant color makes it one of the most breathtaking moments in film history. There is no doubt that color is an effective tool of expression, but it often comes at a price: image colorization is one of the most laborious and expensive stages in producing modern animated films and manga. An automated coloring process can help reduce the cost and time needed to create a manga or an animated film.

Model Introduction
Colorization is a CNN-based algorithm from the University of California, Berkeley that automatically colors grayscale images. It was proposed by Richard Zhang et al. in the paper Colorful Image Colorization, published at ECCV 2016. The model consists of 8 convolutional blocks; each block contains 2 or 3 repeated convolution and ReLU layers followed by a BatchNorm layer. The network contains no pooling layers.
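
As a rough, minimal sketch (not the repository's actual ColorizationModel, and with illustrative channel counts), one such conv block could be expressed with MindSpore's nn API as follows:

import mindspore.nn as nn

def conv_block(in_ch, out_ch, num_convs=2, stride=1):
    """One Colorization-style block: repeated Conv2d + ReLU, then BatchNorm (no pooling)."""
    layers = []
    for i in range(num_convs):
        layers.append(nn.Conv2d(in_ch if i == 0 else out_ch, out_ch, kernel_size=3,
                                stride=stride if i == num_convs - 1 else 1,
                                pad_mode='pad', padding=1, has_bias=True))
        layers.append(nn.ReLU())
    layers.append(nn.BatchNorm2d(out_ch))
    return nn.SequentialCell(layers)

# Example: a first block mapping the 1-channel grayscale (L) input to 64 channels,
# downsampling by stride 2 on its last convolution
block1 = conv_block(1, 64, num_convs=2, stride=2)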

Network Features

  1. A loss function designed to handle the multimodal uncertainty of the colorization problem, preserving the diversity of colors.
  2. The image colorization task is reformulated as a self-supervised representation learning task.
  3. State-of-the-art results on several benchmarks.

Complete sample code: Colorization.ipynb

If you are interested in MindSpore, you can follow the Shengsi MindSpore community.


1. Environment preparation

1. Go to the ModelArts official website

ModelArts is a cloud platform that helps users quickly create and deploy models and manage full-cycle AI workflows. Select the cloud platform below to start using Shengsi MindSpore: obtain the installation command, install the MindSpore 2.0.0-alpha version, and enter the ModelArts official website from the Shengsi tutorial.

Choose CodeLab below to experience it immediately

Wait for the environment to be built

2. Use ModelArts to experience the example

Go to the Shengsi MindSpore official website and click Install at the top

Get the installation command

Switch the instance specification in ModelArts

Open a Terminal and enter the installation command

conda install mindspore=2.0.0a0 -c mindspore -c conda-forge
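
After the installation finishes, you can optionally verify it from the same terminal:

python -c "import mindspore; mindspore.run_check()"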

Then click Clone a Repository in the sidebar and enter

https://github.com/mindspore-courses/applications.git

2. Data processing

Before starting the experiment, please ensure that the Python environment and the MindSpore Vision suite have been installed locally.

Data preparation

This case uses the ImageNet dataset for both training and testing. Please download it from the official website. The training set contains 1,000 categories, totaling approximately 1.2 million images; the validation set, used here as the test set, contains 50,000 images.

The directory structure of the decompressed dataset is as follows:

./dataset/
├── ILSVRC2012_devkit_t12.tar.gz
├── train/
└── val/
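
A quick sanity check of the layout (a hypothetical helper, not part of the repository) can be run before training:

import os

dataset_root = './dataset'
for sub in ('train', 'val'):
    path = os.path.join(dataset_root, sub)
    assert os.path.isdir(path), f'missing directory: {path}'
    print(sub, '->', len(os.listdir(path)), 'entries')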

Training set visualization

import argparse

import matplotlib.pyplot as plt
import numpy as np

from src.process_datasets.data_generator import ColorizationDataset


# Load parameters
parser = argparse.ArgumentParser()
parser.add_argument('--image_dir', type=str, default='./dataset/train', help='path to dataset')
parser.add_argument('--batch_size', type=int, default=4)
parser.add_argument('--num_parallel_workers', type=int, default=1)
parser.add_argument('--shuffle', type=bool, default=True)
args = parser.parse_args(args=[])
plt.figure()

# Load the dataset
dataset = ColorizationDataset(args.image_dir, args.batch_size, args.shuffle, args.num_parallel_workers)
data = dataset.run()
show_data = next(data.create_tuple_iterator())
show_images_original, _ = show_data
show_images_original = show_images_original.asnumpy()
# Plot the first four images of the batch
for i in range(1, 5):
    plt.subplot(1, 4, i)
    temp = show_images_original[i-1]
    temp = np.clip(temp, 0, 1)
    plt.imshow(temp)
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0)
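
For context, the data pipeline works in the CIE Lab color space: the L (lightness) channel serves as the grayscale network input, and the two ab chrominance channels are the prediction target (the images, img_ab pairs seen later). Below is a minimal sketch of that conversion, assuming scikit-image is installed and with 'example.png' as a placeholder path; the repository's ColorizationDataset does its own equivalent preprocessing internally.

from skimage import color, io

rgb = io.imread('example.png')   # H x W x 3 RGB image
lab = color.rgb2lab(rgb)         # convert to CIE Lab
l_channel = lab[:, :, 0]         # lightness: the grayscale network input
ab_channels = lab[:, :, 1:]      # chrominance: the prediction target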

Build the network

After the data is processed, the network is built. The network structure of Colorization is relatively simple: a plain CNN. The overall structure is shown in the figure below.

[Figure: Colorization network architecture]

The detailed configuration of the network is:

[Table: detailed network configuration]

In the configuration table, X is the spatial resolution of the layer's output and C is the number of output channels. S is the layer stride: a value greater than 1 means downsampling after the convolution, and a value less than 1 means upsampling before the convolution. D is the kernel dilation, and Sa is the accumulated stride over all preceding layers. The effective dilation relative to the input is the layer dilation multiplied by the accumulated stride; for example, a layer with accumulated stride Sa = 8 and dilation D = 2 has an effective dilation of 16. BN indicates whether a BatchNorm layer follows the layer, and L indicates whether a 1x1 convolution and cross-entropy loss layer is applied.

Loss function

The network treats colorization as multinomial classification over Q = 313 quantized ab color bins. Given the predicted distribution Z_hat and the soft-encoded ground truth Z, the paper defines a rebalanced multinomial cross-entropy loss:

L_cl(Z_hat, Z) = -sum_{h,w} v(Z_{h,w}) * sum_q Z_{h,w,q} * log(Z_hat_{h,w,q})

Category rebalancing

Because natural images are dominated by desaturated ab values, each pixel is reweighted by the rarity of its nearest color bin:

v(Z_{h,w}) = w_{q*}, where q* = argmax_q Z_{h,w,q}
w ∝ ((1 - λ) * p_tilde + λ / Q)^(-1), normalized so that E_{p_tilde}[w] = 1

Here p_tilde is the smoothed empirical distribution of ab bins, and the paper uses λ = 1/2.

Classification probability to point estimation

To map the predicted per-pixel distribution back to a single ab value, the paper takes the annealed mean of the distribution:

H(Z_{h,w}) = E[f_T(Z_{h,w})], where f_T(z) = exp(log(z) / T) / sum_q exp(log(z_q) / T)

The temperature T = 0.38 interpolates between the distribution's mean (T = 1) and its mode (T → 0).
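
A minimal NumPy sketch of the annealed-mean decoding (illustrative only; the repository's decode utility in src.utils.utils implements the actual post-processing used below):

import numpy as np

def annealed_mean(z, bin_centers, temperature=0.38):
    """Decode per-pixel distributions over Q ab bins into ab values.

    z: (Q, H, W) per-bin probabilities; bin_centers: (Q, 2) ab value of each bin.
    """
    logits = np.log(z + 1e-8) / temperature
    logits -= logits.max(axis=0, keepdims=True)   # numerical stability
    probs = np.exp(logits)
    probs /= probs.sum(axis=0, keepdims=True)     # renormalize at each pixel
    # Expected ab value under the annealed distribution: (H, W, 2)
    return np.tensordot(probs, bin_centers, axes=([0], [0]))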

In the training code, this loss is implemented by connecting the network and the weighted cross entropy in a single cell:

import mindspore
import mindspore.nn as nn


class NetLoss(nn.Cell):
    """Connect the network with its loss."""
    def __init__(self, net):
        super(NetLoss, self).__init__(auto_prefix=True)
        self.net = net
        self.loss = nn.CrossEntropyLoss(reduction='none')

    def construct(self, images, targets, boost, mask):
        """Compute the rebalanced cross-entropy loss."""
        outputs = self.net(images)
        boost_nongray = boost * mask                   # zero out near-gray pixels
        squeeze = mindspore.ops.Squeeze(1)
        boost_nongray = squeeze(boost_nongray)
        result = self.loss(outputs, targets)           # per-pixel cross entropy
        result_loss = (result * boost_nongray).mean()  # apply rebalancing weights
        return result_loss

3. Model implementation

MindSpore requires operations such as loss functions and optimizers to be subclasses of nn.Cell, so we define a custom ColorModel class to connect the network and the loss.

import mindspore.nn as nn


class ColorModel(nn.Cell):
    """Wrap the one-step training cell for the Colorization network."""

    def __init__(self, my_train_one_step_cell_for_net):
        super(ColorModel, self).__init__(auto_prefix=True)
        self.my_train_one_step_cell_for_net = my_train_one_step_cell_for_net

    def construct(self, result, targets, boost, mask):
        """Run one training step and return the loss."""
        loss = self.my_train_one_step_cell_for_net(result, targets, boost,
                                                   mask)
        return loss

Algorithm flow

[Figure: algorithm flowchart]

Model training

Instantiate the loss function and the optimizer, wrap the network in a one-step training cell, and start training.

import argparse
import os
from tqdm import tqdm

import mindspore
import mindspore.nn as nn
from mindspore import context
from mindspore import ops
import numpy as np
import matplotlib.pyplot as plt
from src.utils.utils import PriorBoostLayer, NNEncLayer, NonGrayMaskLayer, decode

from src.model.model import ColorizationModel
from src.model.colormodel import ColorModel
from src.process_datasets.data_generator import ColorizationDataset
from src.losses.loss import NetLoss
import warnings

warnings.filterwarnings('ignore')
# Load parameters

parser = argparse.ArgumentParser()
parser.add_argument('--device_target',
                    default='GPU',
                    choices=['CPU', 'GPU', 'Ascend'],
                    type=str)
parser.add_argument('--device_id', default=1, type=int)
parser.add_argument('--image_dir',
                    type=str,
                    default='./dataset/train',
                    help='path to dataset')
parser.add_argument('--checkpoint_dir',
                    type=str,
                    default='./checkpoints',
                    help='path for saving trained model')
parser.add_argument('--test_dirs',
                    type=str,
                    default='./images',
                    help='path for saving intermediate colorized results')
parser.add_argument('--resource', type=str, default='./src/resources/')
parser.add_argument('--shuffle', type=bool, default=True)
parser.add_argument('--num_epochs', type=int, default=2)
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--num_parallel_workers', type=int, default=1)
parser.add_argument('--learning_rate', type=float, default=0.5e-4)
parser.add_argument('--save_step',
                    type=int,
                    default=200,
                    help='step size for saving trained models')
args = parser.parse_args(args=[])

context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)

encode_layer = NNEncLayer(args)
boost_layer = PriorBoostLayer(args)
non_gray_mask = NonGrayMaskLayer()

# Instantiate the network
net = ColorizationModel()

# Set up the optimizer
net_args = nn.Adam(net.trainable_params(), learning_rate=args.learning_rate)

# Instantiate NetLoss
net_with_criterion = NetLoss(net)

# Instantiate TrainOneStepWithLossScaleCell
scale_sense = nn.FixedLossScaleUpdateCell(1)
myTrainOneStepCellForNet = nn.TrainOneStepWithLossScaleCell(
    net_with_criterion, net_args, scale_sense=scale_sense)
colormodel = ColorModel(myTrainOneStepCellForNet)
colormodel.set_train()

# Load the dataset
dataset = ColorizationDataset(args.image_dir, args.batch_size, args.shuffle,
                              args.num_parallel_workers)
data = dataset.run().create_tuple_iterator()

for epoch in range(args.num_epochs):
    iters = 0

    # Read batches for this epoch
    for images, img_ab in tqdm(data):
        images = ops.expand_dims(images, 1)
        encode, max_encode = encode_layer.forward(img_ab)
        targets = mindspore.Tensor(max_encode, dtype=mindspore.int32)
        boost = mindspore.Tensor(boost_layer.forward(encode),
                                 dtype=mindspore.float32)
        mask = mindspore.Tensor(non_gray_mask.forward(img_ab),
                                dtype=mindspore.float32)
        net_loss = colormodel(images, targets, boost, mask)
        # Print training progress; TrainOneStepWithLossScaleCell returns
        # (loss, overflow, loss_scale), so net_loss[0] is the loss value
        print('[%d/%d]\tLoss_net: %.4f' % (epoch + 1, args.num_epochs, net_loss[0]))
        # Periodically save checkpoints and sample outputs
        if iters % args.save_step == 0:
            if not os.path.exists(args.checkpoint_dir):
                os.makedirs(args.checkpoint_dir)
            mindspore.save_checkpoint(
                net,
                os.path.join(args.checkpoint_dir, 'net' + str(epoch + 1) + '_' +
                             str(iters) + '.ckpt'))
            img_ab_313 = net(images)
            out_max = np.argmax(img_ab_313[0].asnumpy(), axis=0)
            color_img = decode(images, img_ab_313, args.resource)
            if not os.path.exists(args.test_dirs):
                os.makedirs(args.test_dirs)
            plt.imsave(
                args.test_dirs + '/' + str(epoch + 1) + '_' + str(iters) +
                '_infer.png', color_img)
        iters = iters + 1


Model inference

Run the code below to feed a grayscale image into the network and generate an image with reasonable colors.

import argparse
import os

import matplotlib.pyplot as plt
import mindspore
import numpy as np
from mindspore import (context, load_checkpoint, load_param_into_net, ops)
from mindspore.train.model import Model
from tqdm import tqdm

from src.model.model import ColorizationModel
from src.process_datasets.data_generator import ColorizationDataset
from src.utils.utils import decode


parser = argparse.ArgumentParser()
parser.add_argument('--img_path', type=str, default='./dataset/val')
parser.add_argument('--ckpt_path', type=str, default='./checkpoints/net44_1600.ckpt')
parser.add_argument('--resource', type=str, default='./src/resources/')
parser.add_argument('--device_target', default='GPU', choices=['CPU', 'GPU', 'Ascend'], type=str)
parser.add_argument('--device_id', default=1, type=int)
parser.add_argument('--infer_dirs', default='./dataset/output', type=str)
args = parser.parse_args(args=[])


context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)

# Instantiate the network
net = ColorizationModel()

# Load checkpoint parameters
param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(net, param_dict)
colorizer = Model(net)
dataset = ColorizationDataset(args.img_path, 1, prob=0)
data = dataset.run().create_tuple_iterator()
iters = 0

if not os.path.exists(args.infer_dirs):
    os.makedirs(args.infer_dirs)

# Colorize each image in a loop
for images, img_ab in tqdm(data):
    images = ops.expand_dims(images, 1)
    img_ab_313 = colorizer.predict(images)
    out_max = np.argmax(img_ab_313[0].asnumpy(), axis=0)
    color_img = decode(images, img_ab_313, args.resource)
    plt.imsave(args.infer_dirs+'/'+str(iters)+'_infer.png', color_img)
    iters = iters + 1


Summary

This case provides a detailed explanation of the model proposed in the Colorful Image Colorization paper, walks readers through the full algorithm pipeline, and discusses the strengths and weaknesses of Colorization. For the complete code, refer to the MindSpore Vision suite.

Origin: blog.csdn.net/qq_46207024/article/details/129956065