华为开源自研AI框架昇思MindSpore应用案例:Colorization自动着色

自动着色算法之Colorization
当桃乐丝在1939年的电影《绿野仙踪》中走进奥兹国时,从黑白到鲜艳的色彩的转变使它成为电影史上最令人叹为观止的时刻之一。毫无疑问,颜色是一种有效的表达工具,但它们通常是有代价的。在制作现代动画电影和漫画时,图像着色是最费力和昂贵的阶段之一。自动着色过程可以帮助减少制作漫画或动画电影所需的成本和时间

模型简介
Colorization算法是来自加里福利亚大学的一项研究,采用的是CNN的结构。该算法可以实现灰度图像的自动着色,由Richard
Zhang等人在论文Colorful Image
Colorization中提出,并发表在2016年的ECCV会议中。该模型由8个conv层组成,每个conv层由2个或3个重复的卷积层和ReLU层组成,后面跟着一个BatchNorm层。网络中不包含池化层。

网络特点

  1. 设计了一个合适的损失函数来处理着色问题中的多模不确定性,维持了颜色的多样性。
  2. 将图像着色任务转化为一个自监督表达学习的任务。
  3. 在一些基准模型上获得了最好的效果。

完整的样例代码:Colorization.ipynb

如果你对MindSpore感兴趣,可以关注昇思MindSpore社区

在这里插入图片描述

在这里插入图片描述

一、环境准备

1.进入ModelArts官网

云平台帮助用户快速创建和部署模型,管理全周期AI工作流,选择下面的云平台以开始使用昇思MindSpore,获取安装命令,安装MindSpore2.0.0-alpha版本,可以在昇思教程中进入ModelArts官网

在这里插入图片描述

选择下方CodeLab立即体验

在这里插入图片描述

等待环境搭建完成

在这里插入图片描述

2.使用ModelArts体验实例

进入昇思MindSpore官网,点击上方的安装

在这里插入图片描述

获取安装命令

在这里插入图片描述

在ModelArts中切换规格

在这里插入图片描述

打开一个Terminal,输入安装命令

conda install mindspore=2.0.0a0 -c mindspore -c conda-forge

在这里插入图片描述

再点击侧边栏中的Clone a Repository,输入

https://github.com/mindspore-courses/applications.git

在这里插入图片描述

二、数据处理

开始实验之前,请确保本地已经安装了Python环境并安装了MindSpore Vision套件。

数据准备

本案例使用ImageNet数据集作为训练集和测试集。请在官网下载。训练集中包含1000个类别,总计大约120万张图片,测试集中包含5万图片。

解压后的数据集目录结构如下:

.dataset/
├── ILSVRC2012_devkit_t12.tar.gz
├── train/
└── val/

训练集可视化

import os
import argparse
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import mindspore
from src.process_datasets.data_generator import ColorizationDataset


#加载参数
parser = argparse.ArgumentParser()
parser.add_argument('--image_dir', type=str, default='./dataset/train', help='path to dataset')
parser.add_argument('--batch_size', type=int, default=4)
parser.add_argument('--num_parallel_workers', type=int, default=1)
parser.add_argument('--shuffle', type=bool, default=True)
args = parser.parse_args(args=[])
plt.figure()

#加载数据集
dataset = ColorizationDataset(args.image_dir, args.batch_size, args.shuffle, args.num_parallel_workers)
data = dataset.run()
show_data = next(data.create_tuple_iterator())
show_images_original, _ = show_data
show_images_original = show_images_original.asnumpy()
#循环处理
for i in range(1, 5):
    plt.subplot(1, 4, i)
    temp = show_images_original[i-1]
    temp = np.clip(temp, 0, 1)
    plt.imshow(temp)
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0)

在这里插入图片描述

构建网络

处理完数据后进行网络的搭建,Colorization的网络结构较为简单,采用CNN的网络结构。具体结构如下图所示
在这里插入图片描述
网络的详细配置为:
在这里插入图片描述
其中X输出的空间分辨率,C输出的通道数;S计算步幅,大于1表示卷积后下采样,小于1表示卷积前上采样;D内核扩张;Sa在所有前一层的累积步数(积于前一层的所有步数);相对于输入的层的有效膨胀(层膨胀乘以累积步幅);BN层后是否使用BatchNorm层;L表示是否施加了1x1的卷积和交叉熵损失层。

损失函数

在这里插入图片描述
分类再平衡
在这里插入图片描述

分类概率到点估计

在这里插入图片描述

class NetLoss(nn.Cell):
    """连接网络和损失"""
    def __init__(self, net):
        super(NetLoss, self).__init__(auto_prefix=True)
        self.net = net
        self.loss = nn.CrossEntropyLoss(reduction='none')

    def construct(self, images, targets, boost, mask):
        """ build network """
        outputs = self.net(images)
        boost_nongray = boost * mask
        squeeze = mindspore.ops.Squeeze(1)
        boost_nongray = squeeze(boost_nongray)
        result = self.loss(outputs, targets)
        result_loss = (result * boost_nongray).mean()
        return result_loss

在这里插入图片描述

三、模型实现

MindSpore要求将损失函数、优化器等操作也看做nn.Cell的子类,所以我们可以自定义Color类,将网络和loss连接起来。

class ColorModel(nn.Cell):
    """定义Colorization网络"""

    def __init__(self, my_train_one_step_cell_for_net):
        super(ColorModel, self).__init__(auto_prefix=True)
        self.my_train_one_step_cell_for_net = my_train_one_step_cell_for_net

    def construct(self, result, targets, boost, mask):
        loss = self.my_train_one_step_cell_for_net(result, targets, boost,
                                                   mask)
        return loss

在这里插入图片描述

算法流程

在这里插入图片描述

模型训练

实例化损失函数,优化器,使用Model接口编译网络,开始训练。

import argparse
import os
from tqdm import tqdm

import mindspore
import mindspore.nn as nn
from mindspore import context
from mindspore import ops
import numpy as np
import matplotlib.pyplot as plt
from src.utils.utils import PriorBoostLayer, NNEncLayer, NonGrayMaskLayer, decode

from src.model.model import ColorizationModel
from src.model.colormodel import ColorModel
from src.process_datasets.data_generator import ColorizationDataset
from src.losses.loss import NetLoss
import warnings

warnings.filterwarnings('ignore')
#加载参数

parser = argparse.ArgumentParser()
parser.add_argument('--device_target',
                    default='GPU',
                    choices=['CPU', 'GPU', 'Ascend'],
                    type=str)
parser.add_argument('--device_id', default=1, type=int)
parser.add_argument('--image_dir',
                    type=str,
                    default='./dataset/train',
                    help='path to dataset')
parser.add_argument('--checkpoint_dir',
                    type=str,
                    default='./checkpoints',
                    help='path for saving trained model')
parser.add_argument('--test_dirs',
                    type=str,
                    default='./images',
                    help='path for saving trained model')
parser.add_argument('--resource', type=str, default='./src/resources/')
parser.add_argument('--shuffle', type=bool, default=True)
parser.add_argument('--num_epochs', type=int, default=2)
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--num_parallel_workers', type=int, default=1)
parser.add_argument('--learning_rate', type=float, default=0.5e-4)
parser.add_argument('--save_step',
                    type=int,
                    default=200,
                    help='step size for saving trained models')
args = parser.parse_args(args=[])

if context.get_context('device_id') != args.device_id:
    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)

encode_layer = NNEncLayer(args)
boost_layer = PriorBoostLayer(args)
non_gray_mask = NonGrayMaskLayer()

#网络实例化
net = ColorizationModel()

#设置优化器
net_args = nn.Adam(net.trainable_params(), learning_rate=args.learning_rate)

#实例化NetLoss
net_with_criterion = NetLoss(net)

#实例化TrainOneStepWithLossScaleCell
scale_sense = nn.FixedLossScaleUpdateCell(1)
myTrainOneStepCellForNet = nn.TrainOneStepWithLossScaleCell(
    net_with_criterion, net_args, scale_sense=scale_sense)
colormodel = ColorModel(myTrainOneStepCellForNet)
colormodel.set_train()

#加载数据集
dataset = ColorizationDataset(args.image_dir, args.batch_size, args.shuffle,
                              args.num_parallel_workers)
data = dataset.run().create_tuple_iterator()

for epoch in range(args.num_epochs):
    iters = 0

    #为每轮训练读入数据
    for images, img_ab in tqdm(data):
        images = ops.expand_dims(images, 1)
        encode, max_encode = encode_layer.forward(img_ab)
        targets = mindspore.Tensor(max_encode, dtype=mindspore.int32)
        boost = mindspore.Tensor(boost_layer.forward(encode),
                                 dtype=mindspore.float32)
        mask = mindspore.Tensor(non_gray_mask.forward(img_ab),
                                dtype=mindspore.float32)
        net_loss = colormodel(images, targets, boost, mask)
        #输出训练数据
        print('[%d/%d]\tLoss_net:: %.4f' % (epoch + 1, args.num_epochs, net_loss[0]))
        #中间保存训练结果
        if iters % args.save_step == 0:
            if not os.path.exists(args.checkpoint_dir):
                os.makedirs(args.checkpoint_dir)
            mindspore.save_checkpoint(
                net,
                os.path.join(args.checkpoint_dir, 'net' + str(epoch + 1) + '_' +
                             str(iters) + '.ckpt'))
            img_ab_313 = net(images)
            out_max = np.argmax(img_ab_313[0].asnumpy(), axis=0)
            color_img = decode(images, img_ab_313, args.resource)
            if not os.path.exists(args.test_dirs):
                os.makedirs(args.test_dirs)
            plt.imsave(
                args.test_dirs + '/' + str(epoch + 1) + '_' + str(iters) +
                '%s_infer.png', color_img)
        iters = iters + 1

在这里插入图片描述

在这里插入图片描述

模型推理

运行下面代码,将一张灰度图像输入到网络中,即可生成具有合理色彩的图像。

import argparse
import os

import matplotlib.pyplot as plt
import mindspore
import numpy as np
from mindspore import (context, load_checkpoint, load_param_into_net, ops)
from mindspore.train.model import Model
from tqdm import tqdm

from src.model.model import ColorizationModel
from src.process_datasets.data_generator import ColorizationDataset
from src.utils.utils import decode


parser = argparse.ArgumentParser()
parser.add_argument('--img_path', type=str, default='./dataset/val')
parser.add_argument('--ckpt_path', type=str, default='./checkpoints/net44_1600.ckpt')
parser.add_argument('--resource', type=str, default='./src/resources/')
parser.add_argument('--device_target', default='GPU', choices=['CPU', 'GPU', 'Ascend'], type=str)
parser.add_argument('--device_id', default=1, type=int)
parser.add_argument('--infer_dirs', default='./dataset/output', type=str)
args = parser.parse_args(args=[])


mindspore.context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)

#实例化网络
net = ColorizationModel()

#加载参数
param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(net, param_dict)
colorizer = Model(net)
dataset = ColorizationDataset(args.img_path, 1, prob=0)
data = dataset.run().create_tuple_iterator()
iters = 0

if not os.path.exists(args.infer_dirs):
    os.makedirs(args.infer_dirs)

#循环处理图像
for images, img_ab in tqdm(data):
    images = ops.expand_dims(images, 1)
    img_ab_313 = colorizer.predict(images)
    out_max = np.argmax(img_ab_313[0].asnumpy(), axis=0)
    color_img = decode(images, img_ab_313, args.resource)
    plt.imsave(args.infer_dirs+'/'+str(iters)+'_infer.png', color_img)
    iters = iters + 1

在这里插入图片描述
在这里插入图片描述

总结

本案例对Colorful Image
Colorization文中提出的模型进行了详细的解释,向读者完整地展现了该算法的流程,分析了Colorization在着色方面的优势和存在的不足。如需查看详细代码,可参考MindSpore
Vision套件。

猜你喜欢

转载自blog.csdn.net/qq_46207024/article/details/129956065