【SHM】Semantic Human Matting抠图算法调试

前言:

2018年阿里的论文《Semantatic Human Matting》给出了抠图领域的一个新方法,可惜阿里并没有公布源码,而牛人在Github上对这个论文进行了复现,我也是依赖Github上的工程进行钻研,而在调试的过程中,发现有一些地方原作者并没有检验通过就上传,导致训练过程出错,这篇博客就是讲解如何调试通过Github上的Semantic_Human_Matting工程的训练以及测试的代码。

-------------------------------------------------------------------------------------------------------------------------

申明:

  • 写博客的初衷一是为了记录,二也是为后来人填坑——测试效果的好坏受算法结构、受数据集、受训练次数等因素的影响,留言板处不要因为你的结果表现不优良而无视博主无偿付出、甚至恶评相向,这样的白嫖党我劝你善良。

-------------------------------------------------------------------------------------------------------------------------

一、SHM网络简单讲解

通过下面Semantic_Human_Matting网络图开始讲解SHM的网络设计:
在这里插入图片描述

SHM的网络大致分为三个部分:

  1. T-Net网络部分:这部分的作用主要是预测生成trimap图。网络的输入是原图 + mask图;
  2. M-Net网络部分:这部分的作用主要是预测生成alpha图。网络的输入来源于三部分:第一个是原图(上图最左边的那张),第二个是原图对应的mask图(真正输入到网络中的mask图会被拆分成前景图 + 背景图两部分,也就是上图中的 F s F_s Fs B s B_s Bs),第三个是trimap图(真正输入到网络中的只要trimap图的不确定区域,也就是上图中的 U s U_s Us),预测得到上图中的 α r α_r αr
  3. Fusion Module这部分的作用主要是融合得到精准的alpha图。最后精准 α p α_p αp遮罩图的概率估计是: α p = F s + U s α r α_p = F_s + U_sα_r αp=Fs+Usαr

-------------------------------------------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------------------------------------------

二、SHM数据集调整说明

2.1、工程下载,以及环境配置

Github上的Semantic_Human_Matting工程链接在此处此处此处,先下载解压;

根据工程主页上的说明,需要的是python3.5/3.6,torch>=0.4.0,以及opencv-python,我配置的机器环境是ubuntu16.04 + cuda10.0 + python3.6.12 + torch0.4.1 + opencv3.4.3。Windows机器我好像配置过,好像没通过(记不大清楚了,有兴趣的去试一试)

-------------------------------------------------------------------------------------------------------------------------

2.2、下载数据集

2.2.1、最头痛的就是数据集的建立,因为建立大型数据集耗时耗力。所幸工程主页里作者给出了他找到的数据集,在这里对作者及爱分割公司表示感谢,数据集的链接在此处此处此处,密码是:dzsn,下载解压。

2.2.2、解压后可以看到其下主要包含两个文件夹:
在这里插入图片描述

  • clip_img文件夹:其下都是原图;
  • matting文件夹:其下都是原图对应的mask图,但是需要处理一下;

注意:整个数据集包含3W+张图片,预处理全部文件的话很耗时,所以在调试阶段博主强烈建议用其中某一个文件夹就行了。
注意:整个数据集包含3W+张图片,预处理全部文件的话很耗时,所以在调试阶段博主强烈建议用其中某一个文件夹就行了。
注意:整个数据集包含3W+张图片,预处理全部文件的话很耗时,所以在调试阶段博主强烈建议用其中某一个文件夹就行了。

2.2.3、在工程data目录下新建mattingclip_img文件夹,再将数据集mattingclip_img文件夹下的挑选任意一个相同文件夹对应放入工程目录中,隶属关系如下:
在这里插入图片描述

-------------------------------------------------------------------------------------------------------------------------

2.3、matting图生成对应的mask图:

先在data文件夹下新建zcm_matting_get_mask.py文件,代码如下,然后执行这个py文件,完成后可以在data目录下看到生成了一个新的mask文件夹,其下存储着黑白底的mask图。

import os
import cv2


matting_path = "matting/"
mask_path = "mask/"

# test
# for mask_name in os.listdir(matting_path):
#     in_image = cv2.imread(matting_path + mask_name, cv2.IMREAD_UNCHANGED)
#     alpha = in_image[:,:,3]
#     cv2.imwrite(mask_path + mask_name, alpha)

for name_0 in os.listdir(matting_path):
    if not os.path.exists(mask_path + "/" + name_0):
        os.makedirs(mask_path + "/" + name_0)
    for name_1 in os.listdir(matting_path + "/" + name_0):
        if not os.path.exists(mask_path + name_0 + "/" + name_1):
            os.mkdir(mask_path + name_0 + "/" + name_1)
        for name_2 in os.listdir(matting_path + "/" + name_0 + "/" + name_1):

            pic_input_path = matting_path + "/" + name_0 + "/" + name_1 + "/" + name_2
            pic_output_path = mask_path + "/" + name_0 + "/" + name_1 + "/" + name_2
            print("pic_input_path=", pic_input_path)

            in_image = cv2.imread(pic_input_path, cv2.IMREAD_UNCHANGED)
            alpha = in_image[:, :, 3]
            cv2.imwrite(pic_output_path, alpha)

-------------------------------------------------------------------------------------------------------------------------

2.4、生成训练数据的TXT目录:

先在data文件夹下新建zcm_get_train_txt.py文件,代码如下,然后执行这个py文件,完成后可以在data目录下看到生成了一个新的train.txt文件,打开里面存储图片的路径。
在这里插入图片描述

import os

pic_path = "matting/"

with open("train.txt", "w", encoding="UTF-8") as ff:
    for name_0 in os.listdir(pic_path):
        for name_1 in os.listdir(pic_path + "/" + name_0):
            for name_2 in os.listdir(pic_path + "/" + name_0 + "/" + name_1):
                pic_input_path = name_0 + "/" + name_1 + "/" + name_2
                ff.write(pic_input_path + "\n")
    ff.close()
print("well done____________!")

-------------------------------------------------------------------------------------------------------------------------

2.5、由mask图生成trimap图:

2.5.1:像下面一样注释掉gen_trimap.py第36/42/48行的断言语句;

# assert(cnt1 == cnt2 + cnt3)

2.5.2:在gen_trimap.py第四行添加语句,引入os库;

import os

2.5.3:在gen_trimap.py第64行后,添加如下代码;

trimap_name_1 = trimap_name.split("/")[:-1]
trimap_path = "/".join(trimap_name_1)
if not os.path.exists(trimap_path):
    os.makedirs(trimap_path)

在这里插入图片描述

2.5.4:执行sh gen_trimap.sh脚本,就可以生成得到trimap文件夹,及其其下的trimap图片;

-------------------------------------------------------------------------------------------------------------------------

2.6、生成alpha图:

说明:这里给出两种生成alpha图的方法:

  1. 用工程自带的knn_matting.sh脚本生成alpha图;
  2. 直接拷贝mask文件夹,将mask图作为精确的alpha图注入训练;

第一种方法我在简单测试中使用过,该方法非常非常非常的耗时间,而且用该方法处理爱分割公司提供的数据集得到了alpha图,将其注入训练后,对最后的预测的准确率的影响并不大;有兴趣的朋友可以对knn_matting继续改进,将时间效率提高;

我也阐述使用第二种方法的依据:因为爱分割公司的数据集的mask图是精确的,是直接通过matting文件夹生成的。爱分割公司在提供数据集的时候,mask图就是他们人工扣出来的。而knn_matting.sh脚本存在的意义,是对于正常情况下,我们如使用faster-RCNN,DeepLab这样的分割算法得到的mask图是不精准的,才需要使用knn_matting算法处理边界,得到精准的alpha图。

所以这一步,在data文件夹下新建alpha文件夹后,再执行下面复制语句,将mask文件夹下所有文件复制到alpha文件夹;

cp -r mask/* alpha/

至此,数据集准备工作全部做完。

-------------------------------------------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------------------------------------------

三、训练细节调整说明

3.1、写入训练code:

先在Semantic_Human_Matting工程目录下,新建train_code.txt文件,写入如下指令:

# # T-Net训练指令
python3 train.py --dataDir='./data' --saveDir='./ckpt' --trainData='human_matting_data' --trainList='./data/train.txt' --lrdecayType='keep' --nEpochs=200 --save_epoch=1 --load='human_matting' --patch_size=320 --lr=1e-5 --gpus='0,1,2,3' --nThreads=24 --train_batch=48 --train_phase='pre_train_t_net'

# # M-Net训练指令
python3 train.py --dataDir='./data' --saveDir='./ckpt' --trainData='human_matting_data' --trainList='./data/train.txt' --lrdecayType='keep' --nEpochs=400 --save_epoch=1 --load='human_matting' --patch_size=320 --lr=5e-6 --gpus='0,1,2,3' --nThreads=24 --train_batch=48 --train_phase='end_to_end'

第一段是T-Net训练代码,第二段是M-Net训练代码

-------------------------------------------------------------------------------------------------------------------------

3.2、修改train.py文件:

train.py文件第29行后添加一条语句,用来指示GPU的使用情况

parser.add_argument('--gpus', default='0,1,2,3', help='gpus number')

在这里插入图片描述

-------------------------------------------------------------------------------------------------------------------------

3.3、修改dataset.py文件:

3.3.1:用如下语句替换dataset.py文件第17/18/19行

image_name = os.path.join(data_dir, 'clip_img', file_name['image'].replace("matting", "clip").replace("png", "jpg"))
trimap_name = os.path.join(data_dir, 'trimap', file_name['trimap'].replace("clip", "matting"))
alpha_name = os.path.join(data_dir, 'alpha', file_name['alpha'].replace("clip", "matting"))

在这里插入图片描述

3.3.2:用如下语句替换dataset.py文件第101/102/103行:

trimap[trimap == 0] = 0
trimap[trimap >= 250] = 2
trimap[np.where(~((trimap == 0) | (trimap == 2)))] = 1

在这里插入图片描述
这里是整个代码中错误最隐蔽的一个,当初也是花了我很长时间才搞定。我解释一下为什么这样做:我们知道trimap图是三色图,但是它的“三色”并不像上图中0/128/255只有这三色,它是在[0, 255]这个区间范围内。所以新改的代码,将这“三色”用区间区分,作为三种不同的label传入训练。

-------------------------------------------------------------------------------------------------------------------------

3.4、开启T-Net训练:

运行train_code.txt第一行代码,开启T-Net训练,如果你报内存不足的错误,就适当调小patch_size,nThreads,train_batch的数值;

python3 train.py --dataDir='./data' --saveDir='./ckpt' --trainData='human_matting_data' --trainList='./data/train.txt' --lrdecayType='keep' --nEpochs=200 --save_epoch=1 --load='human_matting' --patch_size=320 --lr=1e-5 --gpus='0,1,2,3' --nThreads=24 --train_batch=48 --train_phase='pre_train_t_net'

下图是我T-Net训练过程的loss变化,你也可以为得到更好的结果而增大nEpochs训练轮数;
在这里插入图片描述

-------------------------------------------------------------------------------------------------------------------------

3.5、开启M-Net训练:

运行train_code.txt第二行代码,开启M-Net微调训练

python3 train.py --dataDir='./data' --saveDir='./ckpt' --trainData='human_matting_data' --trainList='./data/train.txt' --lrdecayType='keep' --nEpochs=400 --save_epoch=1 --load='human_matting' --patch_size=320 --lr=5e-6 --gpus='0,1,2,3' --nThreads=24 --train_batch=48 --train_phase='end_to_end'

下图是我M-Net训练过程的loss变化,你也可以为得到更好的结果而增大nEpochs训练轮数;
在这里插入图片描述

-------------------------------------------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------------------------------------------

四、测试细节调整说明

4.1:新建test_camera_used.py文件

写入如下代码,代码与test_camera.py文件很相似,只是改了一部分需求,让过程更简洁;

'''
    test camera 

Author: Zhengwei Li
Date  : 2018/12/28
'''
import time
import cv2
import torch 
import argparse
import numpy as np
import os 
import torch.nn.functional as F
os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1, 2, 3'

parser = argparse.ArgumentParser(description='human matting')
parser.add_argument('--model', default='./ckpt/human_matting/model/model_obj.pth', help='preTrained model')
parser.add_argument('--size', type=int, default=320, help='input size')
parser.add_argument('--without_gpu', action='store_true', default=False, help='no use gpu')

args = parser.parse_args()

torch.set_grad_enabled(False)


#################################
#----------------
if args.without_gpu:
    print("use CPU !")
    device = torch.device('cpu')
else:
    if torch.cuda.is_available():
        n_gpu = torch.cuda.device_count()
        print("----------------------------------------------------------")
        print("|       use GPU !      ||   Available GPU number is {} !  |".format(n_gpu))
        print("----------------------------------------------------------")
        device = torch.device('cuda: 0, 1, 2, 3')

#################################
#---------------
def load_model(args):
    print('Loading model from {}...'.format(args.model))
    if args.without_gpu:
        myModel = torch.load(args.model, map_location=lambda storage, loc: storage)
    else:
        myModel = torch.load(args.model)

    myModel.eval()
    myModel.to(device)
    # myModel.cuda()

    return myModel

def seg_process(args, image, net):

    # opencv
    origin_h, origin_w, c = image.shape
    image_resize = cv2.resize(image, (args.size,args.size), interpolation=cv2.INTER_CUBIC)
    image_resize = (image_resize - (104., 112., 121.,)) / 255.0

    tensor_4D = torch.FloatTensor(1, 3, args.size, args.size)

    tensor_4D[0,:,:,:] = torch.FloatTensor(image_resize.transpose(2,0,1))
    inputs = tensor_4D.to(device)

    trimap, alpha = net(inputs)

    trimap_np = trimap[0, 0, :, :].cpu().data.numpy()
    trimap_np = cv2.resize(trimap_np, (origin_w, origin_h), interpolation=cv2.INTER_CUBIC)
    mask_result = np.multiply(trimap_np[..., np.newaxis], image)

    trimap_1 = mask_result.copy()
    mask_result[trimap_1 < 10] = 255
    mask_result[trimap_1 >= 10] = 0
    cv2.imwrite("mask_result.png", mask_result)

    if args.without_gpu:
        alpha_np = alpha[0,0,:,:].data.numpy()
    else:
        alpha_np = alpha[0,0,:,:].cpu().data.numpy()


    alpha_np = cv2.resize(alpha_np, (origin_w, origin_h), interpolation=cv2.INTER_CUBIC)

    fg = np.multiply(alpha_np[..., np.newaxis], image)

    # cv2.imwrite("fg.png", fg)

    # bg = image
    # bg_gray = np.multiply(1 - alpha_np[..., np.newaxis], image)
    # bg_gray = cv2.cvtColor(bg_gray, cv2.COLOR_BGR2GRAY)
    # # print("bg_gray=", bg_gray)
    # bg[:,:,0] = bg_gray
    # bg[:,:,1] = bg_gray
    # bg[:,:,2] = bg_gray
    #
    # # fg[fg<=0] = 0
    # # fg[fg>255] = 255
    # # fg = fg.astype(np.uint8)
    # # out = cv2.addWeighted(fg, 0.7, bg, 0.3, 0)
    #
    # # out = fg + bg
    # # out[out<0] = 0
    # # out[out>255] = 255
    # # out = out.astype(np.uint8)
    #
    # out = fg.copy()
    # out[out<10] = 0
    # out[out>=10] = 255
    # out = out.astype(np.uint8)

    return fg, mask_result


def camera_seg(args, net):

    # videoCapture = cv2.VideoCapture(0)
    #
    # while(1):
    #     # get a frame
    #     ret, frame = videoCapture.read()
    #     frame = cv2.flip(frame,1)
    #     frame_seg = seg_process(args, frame, net)
    #
    #
    #     # show a frame
    #     cv2.imshow("capture", frame_seg)
    #
    #     if cv2.waitKey(1) & 0xFF == ord('q'):
    #         break
    # videoCapture.release()

    test_pic_path = "test_pic/"
    output_path = "result/"
    if not os.path.exists(output_path):
        os.mkdir(output_path)

    time_0 = time.time()
    for name_ in os.listdir(test_pic_path):
        frame = cv2.imread(test_pic_path + name_)
        fg, mask_result = seg_process(args, frame, net)
        print("SUCCESS_____!", test_pic_path + name_)
        cv2.imwrite(output_path + name_.split(".")[0] + "_fg.jpg", fg)
        cv2.imwrite(output_path + name_, mask_result)
    print("time_all = ", time.time() - time_0)


def main(args):
    time_1 = time.time()
    myModel = load_model(args)
    print("lodding_model_time = ", time.time() - time_1)
    camera_seg(args, myModel)


if __name__ == "__main__":
    main(args)

4.2:测试过程

在主目录下新建test_pic文件夹,将测试所用的pic图片存入其中后,运行test_camera_used.py文件,就能在result文件夹下得到预测的结果图。

-------------------------------------------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------------------------------------------

五、最后的说明:

  1. 爱分割公司提供的数据集中,某一个目录中有一个没用的隐藏文件,如果不删除的话,数据准备过程、训练过程会报错——但是我忘了具体在哪个文件夹…
  2. 我训练了一个较好的model,所用的设备是具有4个GTX2080的显卡服务器跑了将近10天,用上了爱分割公司全部数据集 + 自建的一些数据集,因为公司的保密协议,我不能公布这个model,只展示我测试的结果。左边是预测生成图,右边是原图;
  3. 有问题欢迎留言垂询;

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/zzZ_CMing/article/details/109490676