移动端unet人像分割模型--1

  个人对移动端神经网络开发一直饶有兴致。去年腾讯开源了NCNN框架之后,我就一直在关注。近期成功利用别人训练好的mtcnn和mobilefacenet模型,用swift制作了一个iOS版本的人脸识别demo。后续希望把maskrcnn也移植到ncnn,在手机端实现一些有趣的应用;不过unet模型比较简单,干脆就先从它入手。

  基本的网络基于keras版本: https://github.com/TianzhongSong/Person-Segmentation-Keras

  不过keras没办法直接转成ncnn模型,研究过通过onnx模型做中间跳板,采用了一些开源的转换工具,也是一堆问题。NCNN支持几个神经网络训练框架:caffe/mxnet/pytorch,在ncnn的github有一篇issue里nihui推荐采用mxnet,因此mxnet也成为了我的首选。

  利用Person-Segmentation-Keras项目的数据集,同时参考 https://github.com/milesial/Pytorch-UNet/tree/master/unet 这个项目,捣鼓了几段mxnet代码。训练完成后,得到的模型用来测试ncnn转换基本可用。

  转换过程中发现了不少问题。首先是调用ncnn的extract会crash,经过调查,发现mxnet2ncnn工具有bug,blob个数算错了。其次是input层的one_blob_only标志,按我的理解应该是false,但不知道什么原因转换出来的模型里是true,导致forward_layer函数里访问bottoms变量时出现异常。后来一层层extract出来、打印输出的channel/width/height,又发现我把unet.py里本应命名为pool5的层写成了pool4(文章中的code已经纠正),前面的crash说不定也跟这个致命错误有关。只好重新训练模型,又是几个小时的漫长等待,剩下的部分下周再写。部分代码已经更新,请参考: https://github.com/xuduo35/unet_mxnet2ncnn

unetdataiter.py


#!/usr/bin/env python
# coding=utf8

import os
import sys
import random
import cv2
import mxnet as mx
import numpy as np
from mxnet.io import DataIter, DataBatch

sys.path.append('../')

def _letterbox_geometry(src_h, src_w, height, width):
    """Return (new_h, new_w, top, left) for an aspect-preserving fit.

    The source is scaled so that it fills the target along its relatively
    longer side and is centred along the other one.
    """
    if src_h * width >= src_w * height:
        # Source is relatively taller than the target: fill the full height.
        # For a square target this is exactly the `src_h >= src_w` case.
        scale = src_h / height
        new_h = height
        new_w = min(width, max(1, int(src_w / scale)))
    else:
        # Source is relatively wider than the target: fill the full width.
        scale = src_w / width
        new_w = width
        new_h = min(height, max(1, int(src_h / scale)))
    return new_h, new_w, (height - new_h) // 2, (width - new_w) // 2


def get_batch(items, root_path, nClasses, height, width):
    """Load and preprocess one mini-batch of image / label-mask pairs.

    Args:
        items: text lines; the first space-separated field is the image path
            and the last field is the label-image path, both appended to
            ``root_path``.
        root_path: prefix prepended to every path read from ``items``.
        nClasses: number of classes; channel 0 of the label image is compared
            against each class index ``0 .. nClasses - 1``.
        height: target number of rows.
        width: target number of columns.

    Returns:
        ``(data, label)`` as ``mx.nd.array``. Each data sample is a
        ``(3, height, width)`` array scaled to ``[-1, 1]``; each label sample
        is a ``(nClasses, height * width)`` one-hot array.

    Raises:
        IOError: if an image or label file cannot be read.
    """
    x = []
    y = []
    class_ids = np.arange(nClasses).reshape(nClasses, 1, 1)
    for item in items:
        fields = item.split(' ')
        image_path = root_path + fields[0].strip()
        label_path = root_path + fields[-1].strip()
        img = cv2.imread(image_path, 1)
        label_img = cv2.imread(label_path, 1)
        # cv2.imread signals failure by returning None instead of raising.
        if img is None:
            raise IOError("cannot read image: " + image_path)
        if label_img is None:
            raise IOError("cannot read label image: " + label_path)

        # Canvases are (rows, cols) = (height, width); the image is padded
        # with mid-gray (128), the label with class 0.
        im = np.full((height, width, 3), 128, dtype='uint8')
        lim = np.zeros((height, width), dtype='uint8')

        new_h, new_w, top, left = _letterbox_geometry(
            img.shape[0], img.shape[1], height, width)

        img = cv2.resize(img, (new_w, new_h))
        # Nearest-neighbour keeps label values valid class indices; blending
        # interpolation could invent ids that lie between two classes.
        label_img = cv2.resize(label_img, (new_w, new_h),
                               interpolation=cv2.INTER_NEAREST)

        im[top:top + new_h, left:left + new_w, :] = img
        # Only channel 0 of the label image carries the class index.
        lim[top:top + new_h, left:left + new_w] = label_img[:, :, 0]

        # One-hot encode as (nClasses, height, width), then flatten pixels.
        seg_labels = (lim[np.newaxis, :, :] == class_ids).astype('float64')
        seg_labels = seg_labels.reshape(nClasses, height * width)

        im = np.float32(im) / 127.5 - 1
        x.append(im.transpose((2, 0, 1)))
        y.append(seg_labels)

    return mx.nd.array(x), mx.nd.array(y)

class UnetDataIter(mx.io.DataIter):
    """Data iterator that feeds image / mask batches built by ``get_batch``.

    Args:
        root_path: prefix prepended to every path listed in ``path_file``.
        path_file: text file with one "image_path ... label_path" entry per
            line.
        batch_size: number of samples per batch; a trailing partial batch is
            dropped.
        n_classes: number of segmentation classes.
        input_width: target image width.
        input_height: target image height.
        train: when true the sample order is reshuffled on every ``reset``
            and progress is reported as "Training", otherwise "Validating".
    """

    def __init__(self, root_path, path_file, batch_size, n_classes, input_width, input_height, train=True):
        super(UnetDataIter, self).__init__(batch_size)

        # Skip blank lines (e.g. a trailing newline at the end of the list
        # file) so that they are not treated as samples.
        with open(path_file, 'r') as f:
            self.items = [line for line in f if line.strip()]

        # NOTE(review): the spatial dimensions are declared as
        # (input_width, input_height); this only matches channel-first
        # (rows, cols) data when the input is square -- confirm before using
        # non-square sizes.
        self._provide_data = [['data', (batch_size, 3, input_width, input_height)]]
        self._provide_label = [['softmax_label', (batch_size, n_classes, input_width*input_height)]]

        self.root_path = root_path
        self.batch_size = batch_size
        self.num_batches = len(self.items) // batch_size
        self.n_classes = n_classes
        self.input_height = input_height
        self.input_width = input_width
        self.train = train

        self.reset()

    def __iter__(self):
        return self

    def reset(self):
        """Rewind to the first batch, reshuffling the samples when training."""
        self.cur_batch = 0
        self.shuffled_items = list(self.items)
        if self.train:
            random.shuffle(self.shuffled_items)

    def __next__(self):
        return self.next()

    @property
    def provide_data(self):
        return self._provide_data

    @property
    def provide_label(self):
        return self._provide_label

    def next(self):
        """Return the next ``mx.io.DataBatch``.

        Raises:
            StopIteration: once ``num_batches`` batches have been produced.
        """
        if self.cur_batch == 0:
            print("")

        # "\r" + ESC[K returns to the line start and erases the line so the
        # progress counter is redrawn in place; flush because no newline is
        # written.
        mode = "Training " if self.train else "Validating "
        print("\r\033[K" + mode + str(self.cur_batch) + "/" + str(self.num_batches),
              end=' ', flush=True)

        if self.cur_batch >= self.num_batches:
            raise StopIteration

        start = self.cur_batch * self.batch_size
        stop = start + self.batch_size
        data, label = get_batch(self.shuffled_items[start:stop], self.root_path,
                                self.n_classes, self.input_height, self.input_width)
        self.cur_batch += 1

        return mx.io.DataBatch([data], [label])

if __name__ == '__main__':
    # Smoke test: build the training iterator and walk through one epoch.
    root_path = '/datasets/'
    train_file = './data/seg_train.txt'
    batch_size = 16
    n_classes = 2
    img_width = 256
    img_height = 256

    trainiter = UnetDataIter(root_path, train_file, batch_size, n_classes, img_width, img_height, True)

    # The iterator signals exhaustion with StopIteration, which the for-loop
    # consumes; calling next() in an endless loop would end in a traceback.
    for _ in trainiter:
        pass

    # Terminate the in-place progress line.
    print("")

unet.py

import os
os.environ["MXNET_BACKWARD_DO_MIRROR"] = "1"
os.environ["MXNET_CUDNN_AUTOTUNE_DEFAULT"] = "0"

import mxnet as mx
from mxnet import ndarray as F
from skimage.transform import resize
from skimage.io import imsave
import numpy as np
from unetdataiter import UnetDataIter
import matplotlib.pyplot as plt

def dice_coef(y_true, y_pred):
    """Build the symbolic Dice coefficient of ``y_true`` and ``y_pred``.

    Sums are taken over axes (1, 2, 3); both numerator and denominator are
    smoothed by adding 1.
    """
    reduce_axes = (1, 2, 3)
    overlap = mx.sym.sum(mx.sym.broadcast_mul(y_true, y_pred), axis=reduce_axes)
    numerator = 2. * overlap + 1.
    true_total = mx.sym.sum(y_true, axis=reduce_axes)
    pred_total = mx.sym.sum(y_pred, axis=reduce_axes)
    denominator = true_total + pred_total + 1.
    return mx.sym.broadcast_div(numerator, denominator)

def dice_coef_loss(y_true, y_pred):
    """Build the negated, smoothed Dice-style ratio used as training loss.

    All reductions run over axis 1 only, so the result keeps the remaining
    axes of the inputs.
    NOTE(review): with inputs laid out as (batch, class, pixel) axis 1 is the
    class axis rather than the pixel axis -- confirm this is intended.
    """
    overlap = mx.sym.sum(mx.sym.broadcast_mul(y_true, y_pred), axis=1)
    numerator = 2. * overlap + 1.
    true_total = mx.sym.sum(y_true, axis=1)
    pred_total = mx.sym.sum(y_pred, axis=1)
    denominator = mx.sym.broadcast_add(true_total, pred_total) + 1.
    return -mx.sym.broadcast_div(numerator, denominator)

def _conv_bn_relu(x, num_filter, tag):
    """3x3 convolution -> batch norm -> ReLU, named conv/bn/relu + ``tag``."""
    x = mx.sym.Convolution(x, num_filter=num_filter, kernel=(3,3), pad=(1,1), name='conv' + tag)
    x = mx.sym.BatchNorm(x, name='bn' + tag)
    return mx.sym.Activation(x, act_type='relu', name='relu' + tag)


def _double_conv(x, num_filter, stage):
    """Two stacked conv/bn/relu units named ``<stage>_1`` and ``<stage>_2``."""
    for idx in (1, 2):
        x = _conv_bn_relu(x, num_filter, '%d_%d' % (stage, idx))
    return x


def build_unet(batch_size, input_width, input_height, train=True):
    """Build the U-Net symbol.

    Args:
        batch_size: batch size baked into the reshape of the network output.
        input_width: input image width.
        input_height: input image height.
        train: when true the symbol groups the Dice loss (computed against
            the ``softmax_label`` variable) with a gradient-blocked copy of
            the mask output; otherwise only the mask output is returned.

    Returns:
        An ``mx.sym.Group`` as described above. The mask output is a
        two-channel sigmoid whose layer is named ``softmax``.

    NOTE(review): no stride is passed to the Pooling layers and the
    Deconvolution layers use stride (1, 1) with a 2x2 kernel, so the skip
    connections only line up if each pooling shrinks the feature map by one
    pixel rather than halving it -- confirm against MXNet's default pooling
    stride.
    """
    data = mx.sym.Variable(name='data')
    label = mx.sym.Variable(name='softmax_label')

    # Encoder: stages 1..5, each a double conv followed by 2x2 max pooling.
    # The pre-pooling activations are kept for the skip connections.
    skips = []
    x = data
    for stage, filters in enumerate((64, 128, 256, 256, 256), start=1):
        x = _double_conv(x, filters, stage)
        skips.append(x)
        x = mx.sym.Pooling(x, kernel=(2,2), pool_type='max', name='pool%d' % stage)

    # Decoder: stages 6..10, each a transposed conv, a concat with the
    # matching encoder activation (deepest first) and a double conv.
    for stage, filters in zip(range(6, 11), (256, 256, 256, 128, 64)):
        up = mx.sym.Deconvolution(x, num_filter=filters, kernel=(2,2), stride=(1,1), no_bias=True, name='trans_conv%d' % stage)
        merged = mx.sym.concat(up, skips.pop(), dim=1, name='concat%d' % stage)
        x = _double_conv(merged, filters, stage)

    # Output head: 1x1 convolution to two channels followed by a sigmoid.
    conv11 = mx.sym.Convolution(x, num_filter=2, kernel=(1,1), name='conv11_1')
    conv11 = mx.sym.sigmoid(conv11, name='softmax')

    # Flatten the spatial dimensions so the output lines up with the label.
    net = mx.sym.Reshape(conv11, (batch_size, 2, input_width*input_height))

    if not train:
        return mx.sym.Group([conv11])

    loss = mx.sym.MakeLoss(dice_coef_loss(label, net), normalization='batch')
    mask_output = mx.sym.BlockGrad(conv11, 'mask')
    return mx.sym.Group([loss, mask_output])

trainunet.py

import os
os.environ["MXNET_BACKWARD_DO_MIRROR"] = "1"
os.environ["MXNET_CUDNN_AUTOTUNE_DEFAULT"] = "0"

import mxnet as mx
from mxnet import ndarray as F
from skimage.transform import resize
from skimage.io import imsave
import numpy as np
from unetdataiter import UnetDataIter
import matplotlib.pyplot as plt
from unet import build_unet

def main():
    """Train the U-Net on the person-segmentation data set.

    Runs 20 epochs of Adam on GPU 0, prints the exponentially smoothed
    training loss and the mean validation loss after every epoch, and saves
    a checkpoint with prefix ``./unet_person_segmentation`` per epoch.
    """
    root_path = '../datasets/'
    train_file = './data/seg_train.txt'
    val_file = './data/seg_test.txt'
    batch_size = 16
    n_classes = 2
    # Network input size; keep in sync with the prediction and conversion
    # scripts, which bind the model with the same dimensions.
    img_width = 96
    img_height = 96

    train_iter = UnetDataIter(root_path, train_file, batch_size, n_classes, img_width, img_height, True)
    val_iter = UnetDataIter(root_path, val_file, batch_size, n_classes, img_width, img_height, False)

    ctx = [mx.gpu(0)]

    unet_sym = build_unet(batch_size, img_width, img_height)
    unet = mx.mod.Module(unet_sym, context=ctx, data_names=('data',), label_names=('softmax_label',))
    unet.bind(data_shapes=[['data', (batch_size, 3, img_width, img_height)]], label_shapes=[['softmax_label', (batch_size, n_classes, img_width*img_height)]])
    unet.init_params(mx.initializer.Xavier(magnitude=6))

    unet.init_optimizer(optimizer='adam',
                        optimizer_params=(
                            ('learning_rate', 1E-4),
                            ('beta1', 0.9),
                            ('beta2', 0.99)
                        ))

    epochs = 20
    smoothing_constant = .01
    # None until the first training batch has been seen; afterwards an
    # exponential moving average of the per-batch loss.
    moving_loss = None
    for e in range(epochs):
        for batch in train_iter:
            unet.forward_backward(batch)
            loss = unet.get_outputs()[0]
            unet.update()
            curr_loss = F.mean(loss).asscalar()
            if moving_loss is None:
                moving_loss = curr_loss
            else:
                moving_loss = (1 - smoothing_constant) * moving_loss + smoothing_constant * curr_loss
        train_iter.reset()

        val_losses = []
        for batch in val_iter:
            # The module is bound for training, so inference mode has to be
            # requested explicitly; otherwise validation would run the
            # network in training mode.
            unet.forward(batch, is_train=False)
            loss = unet.get_outputs()[0]
            val_losses.append(F.mean(loss).asscalar())
        val_iter.reset()

        # Report NaN instead of failing when an iterator yielded no batch.
        val_loss = np.mean(val_losses) if val_losses else float('nan')
        shown_loss = moving_loss if moving_loss is not None else float('nan')
        print("\nEpoch %i: Moving Training Loss %0.5f, Validation Loss %0.5f" % (e, shown_loss, val_loss))

        unet.save_checkpoint('./unet_person_segmentation', e)


if __name__ == '__main__':
    main()


  以上是训练代码。

  预测代码如下predict.py

import os
os.environ["MXNET_BACKWARD_DO_MIRROR"] = "1"
os.environ["MXNET_CUDNN_AUTOTUNE_DEFAULT"] = "0"
import sys
import cv2
import mxnet as mx
from mxnet import ndarray as F
from skimage.transform import resize
from skimage.io import imsave
import numpy as np
from unetdataiter import UnetDataIter
import matplotlib.pyplot as plt
from unet import build_unet

def post_process_mask(label, img_cols, img_rows, n_classes, p=0.5):
    """Turn per-class scores into a displayable mask.

    ``label`` is reshaped to ``(n_classes, img_cols, img_rows)``, the class
    axis is moved last, and the index of the highest-scoring class per
    position is multiplied by 255. The result is converted with
    ``asnumpy()``. ``p`` is accepted but not used.
    """
    per_class = label.reshape(n_classes, img_cols, img_rows)
    class_last = per_class.transpose([1,2,0])
    winners = class_last.argmax(axis=2)
    scaled = winners * 255
    return scaled.asnumpy()

def load_image(img, width, height):
    """Letterbox ``img`` onto a ``height`` x ``width`` gray canvas.

    The image is resized with its aspect ratio preserved (the longer side
    fills the canvas), centred on a canvas filled with 128, scaled to
    ``[-1, 1]`` and returned channel-first inside a one-element list.
    """
    canvas = np.full((height, width, 3), 128, dtype='uint8')
    src_h, src_w = img.shape[0], img.shape[1]

    if src_h >= src_w:
        # Portrait or square: fill the height, centre horizontally.
        scale = src_h / height
        new_width = int(src_w / scale)
        left = (width - new_width) // 2
        resized = cv2.resize(img, (new_width, height))
        canvas[:, left:left + new_width, :] = resized
    else:
        # Landscape: fill the width, centre vertically.
        scale = src_w / width
        new_height = int(src_h / scale)
        top = (height - new_height) // 2
        resized = cv2.resize(img, (width, new_height))
        canvas[top:top + new_height, :, :] = resized

    normalized = np.float32(canvas) / 127.5 - 1
    chw = normalized.transpose((2,0,1))
    return [chw]

def main():
    """Segment the image given as ``sys.argv[1]`` and display the mask.

    Loads epoch 0 of the ``unet_person_segmentation`` checkpoint, runs the
    network on GPU 0 and shows the input image and the predicted mask in two
    OpenCV windows until a key is pressed.
    """
    batch_size = 16
    n_classes = 2
    img_width = 96
    img_height = 96

    ctx = [mx.gpu(0)]

    sym, arg_params, aux_params = mx.model.load_checkpoint('unet_person_segmentation', 0)
    unet = mx.mod.Module(symbol=sym, context=ctx, label_names=None)

    # The inference module has no label inputs.
    unet.bind(for_training=False, data_shapes=[['data', (batch_size, 3, img_width, img_height)]], label_shapes=None)
    unet.set_params(arg_params, aux_params, allow_missing=True)

    image_path = sys.argv[1]
    testimg = cv2.imread(image_path, 1)
    # cv2.imread returns None for a missing or undecodable file; fail with a
    # clear message instead of crashing later on a None image.
    if testimg is None:
        sys.exit("cannot read image: " + image_path)

    img = load_image(testimg, img_width, img_height)
    unet.predict(mx.io.NDArrayIter(data=[img]))

    outputs = unet.get_outputs()[0]
    cv2.imshow('test', testimg)
    cv2.imshow('mask', post_process_mask(outputs[0], img_width, img_height, n_classes))
    cv2.waitKey()

if __name__ == '__main__':
    if len(sys.argv) < 2:
        # Report misuse on stderr and exit with a failing status so that
        # shell callers can detect the error.
        print("illegal parameters", file=sys.stderr)
        print("usage: %s <image_path>" % sys.argv[0], file=sys.stderr)
        sys.exit(1)

    main()

  剥离训练用的loss部分(只保留名为softmax的推理输出),保存参数用于ncnn模型转换,train2infer.py

扫描二维码关注公众号,回复: 3634457 查看本文章
import os
os.environ["MXNET_BACKWARD_DO_MIRROR"] = "1"
os.environ["MXNET_CUDNN_AUTOTUNE_DEFAULT"] = "0"
import sys
import cv2
import mxnet as mx
from mxnet import ndarray as F
from skimage.transform import resize
from skimage.io import imsave
import numpy as np
from unetdataiter import UnetDataIter
import matplotlib.pyplot as plt
from unet import build_unet

def main():
    """Re-save a training checkpoint with the inference-only symbol.

    Reads the checkpoint identified by ``sys.argv[1]`` (prefix) and
    ``sys.argv[2]`` (epoch), loads its parameters into the network built
    with ``train=False`` and writes the result as epoch 0 of
    ``./unet_person_segmentation``.
    """
    batch_size = 16
    n_classes = 2
    img_width = 96
    img_height = 96

    ctx = [mx.gpu(0)]

    prefix = sys.argv[1]
    epoch = int(sys.argv[2])
    # Only the parameters are reused; the stored training symbol is replaced
    # by the freshly built inference symbol below.
    _train_sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)

    infer_sym = build_unet(batch_size, img_width, img_height, False)
    unet = mx.mod.Module(symbol=infer_sym, context=ctx, label_names=None)
    data_shapes = [['data', (batch_size, 3, img_width, img_height)]]
    unet.bind(for_training=False, data_shapes=data_shapes, label_shapes=unet._label_shapes)
    unet.set_params(arg_params, aux_params, allow_missing=True)

    unet.save_checkpoint('./unet_person_segmentation', 0)

if __name__ == '__main__':
    if len(sys.argv) < 3:
        # Report misuse on stderr and exit with a failing status so that
        # shell callers can detect the error.
        print("illegal parameters", file=sys.stderr)
        print("usage: %s <checkpoint_prefix> <epoch>" % sys.argv[0], file=sys.stderr)
        sys.exit(1)

    main()

猜你喜欢

转载自blog.csdn.net/xiexiecn/article/details/83029787