mxnet学习(一)----利用自己的数据集图像分类

这段时间小编参加了一个叫prcv的挑战赛,其中就有一个项目是多标签图像分类,可是小编一直使用的是caffe框架,这给这个任务带来了比较大的挑战,据说可以通过修改caffe源码来实现多标签分类问题,但是我觉得太麻烦了,也不想去修改我的caffe源码,毕竟caffe也不太好安装,tensorflow,keas等框架小编半生不熟,经过搜索,打听到mxnet比较适合解决这个问题,于是小编开始摸索mxnet,一查官方文档,基本有了了解,特此记录下。
深度学习框架几乎都是大同小异,学习这些无非就是几个步骤
1、准备数据集
2、转换成框架需要的数据格式
3、搭建模型开始训练
4、利用生成的模型进行预测
5、学习迁移学习等
6、目标检测
7、RNN等
接下来我们开始吧
(一)安装mxnet
此过程直接跳过吧,网上太多了,小编非常建议源码安装不要使用pip安装,小编在这给出参考链接,其中cuda等如果是cpu可跳过

https://www.cnblogs.com/whu-zeng/p/6160312.html

(二)准备数据集
本次实验使用的数据集来自小编的各种搜集,其中一个文件夹里是圣诞老人,一个是负样本,小编将要实现二分类,多分类同样可参考,文件结构如图所示,image里包含两类,1代表负样本,0代表圣诞老人
这里写图片描述
这里写图片描述
这样裸着的jpg图片不能直接作为训练数据,需要转化成比较规范的格式。首先,输入CNN的原始图片应当具有同样的尺寸,同样的通道数(如RGB为3通道)等等。另外,mxnet推荐所有的数据都应该以某种DataIter的形式呈现,这样我们通过mxnet的接口就可以很方便地进行训练(见后文)。
(三 转换数据集格式)
mxnet支持将一种.rec格式的数据集直接导入为DataIter,同时提供了一种工具可以将一个裸的图片数据集转化成.rec格式。为了简化过程,我们采用mxnet提供的小工具来将我们的数据集转化为.rec格式(此举与caffe下的create_imagenet.sh是一模一样的)。

mxnet提供的工具在mxnet/tools目录下,如果im2rec.cc/cpp已经编译成可执行文件,则可以使用该可执行文件;否则可以使用同目录下的im2rec.py。我使用的是im2rec.py。
我们找到mxnet/tools/im2rec.py,其源码有错误,小编修改了下确保无误

# -*- coding: utf-8 -*-
from __future__ import print_function
import os
import sys

curr_path = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(curr_path, "../python"))
import mxnet as mx
import random
import argparse
import cv2
import time


def list_image(root, recursive, exts):
    image_list = []
    if recursive:
        cat = {}
        for path, subdirs, files in os.walk(root, followlinks=True):
            subdirs.sort()
            print(len(cat), path)
            for fname in files:
                fpath = os.path.join(path, fname)
                suffix = os.path.splitext(fname)[1].lower()
                if os.path.isfile(fpath) and (suffix in exts):
                    if path not in cat:
                        cat[path] = len(cat)
                    yield (len(image_list), os.path.relpath(fpath, root), cat[path])
    else:
        for fname in os.listdir(root):
            fpath = os.path.join(root, fname)
            suffix = os.path.splitext(fname)[1].lower()
            if os.path.isfile(fpath) and (suffix in exts):
                yield (len(image_list), os.path.relpath(fpath, root), 0)

def write_list(path_out, image_list):
    with open(path_out, 'w') as fout:
        for i, item in enumerate(image_list):
            line = '%d\t' % item[0]
            for j in item[2:]:
                line += '%f\t' % j
            line += '%s\n' % item[1]
            fout.write(line)

def make_list(args):
    image_list = list_image(args.root, args.recursive, args.exts)
    image_list = list(image_list)
    if args.shuffle is True:
        random.seed(100)
        random.shuffle(image_list)
    N = len(image_list)
    chunk_size = (N + args.chunks - 1) / args.chunks
    for i in xrange(args.chunks):
        chunk = image_list[i * chunk_size:(i + 1) * chunk_size]
        if args.chunks > 1:
            str_chunk = '_%d' % i
        else:
            str_chunk = ''
        sep = int(chunk_size * args.train_ratio)
        sep_test = int(chunk_size * args.test_ratio)
        write_list(args.prefix + str_chunk + '_test.lst', chunk[:sep_test])
        write_list(args.prefix + str_chunk + '_train.lst', chunk[sep_test:sep_test + sep])
        write_list(args.prefix + str_chunk + '_val.lst', chunk[sep_test + sep:])

def read_list(path_in):
    with open(path_in) as fin:
        while True:
            line = fin.readline()
            if not line:
                break
            line = [i.strip() for i in line.strip().split('\t')]
            item = [int(line[0])] + [line[-1]] + [float(i) for i in line[1:-1]]
            yield item

def image_encode(args, item, q_out):
    try:
        img = cv2.imread(os.path.join(args.root, item[1]), args.color)
    except:
        print('imread error:', item[1])
        return
    if img is None:
        print('read none error:', item[1])
        return
    if args.center_crop:
        if img.shape[0] > img.shape[1]:
            margin = (img.shape[0] - img.shape[1]) / 2;
            img = img[margin:margin + img.shape[1], :]
        else:
            margin = (img.shape[1] - img.shape[0]) / 2;
            img = img[:, margin:margin + img.shape[0]]
    if args.resize:
        if img.shape[0] > img.shape[1]:
            newsize = (args.resize, img.shape[0] * args.resize / img.shape[1])
        else:
            newsize = (img.shape[1] * args.resize / img.shape[0], args.resize)
        img = cv2.resize(img, newsize)
    if len(item) > 3 and args.pack_label:
        header = mx.recordio.IRHeader(0, item[2:], item[0], 0)
    else:
        header = mx.recordio.IRHeader(0, item[2], item[0], 0)

    try:
        s = mx.recordio.pack_img(header, img, quality=args.quality, img_fmt=args.encoding)
        q_out.put((s, item))
    except Exception, e:
        print('pack_img error:', item[1], e)
        return

def read_worker(args, q_in, q_out):
    while True:
        item = q_in.get()
        if item is None:
            break
        image_encode(args, item, q_out)

def write_worker(q_out, fname, working_dir):
    pre_time = time.time()
    count = 0
    fname_rec = os.path.basename(fname)
    fname_rec = os.path.splitext(fname)[0] + '.rec'
    fout = open(fname+'.tmp', 'w')
    record = mx.recordio.MXRecordIO(os.path.join(working_dir, fname_rec), 'w')
    while True:
        deq = q_out.get()
        if deq is None:
            break
        s, item = deq
        record.write(s)

        line = '%d\t' % item[0]
        for j in item[2:]:
            line += '%f\t' % j
        line += '%s\n' % item[1]
        fout.write(line)

        if count % 1000 == 0:
            cur_time = time.time()
            print('time:', cur_time - pre_time, ' count:', count)
            pre_time = cur_time
        count += 1
    os.rename(fname+'.tmp', fname)

def parse_args():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='Create an image list or \
        make a record database by reading from an image list')
    parser.add_argument('prefix', help='prefix of input/output lst and rec files.')
    parser.add_argument('root', help='path to folder containing images.')

    cgroup = parser.add_argument_group('Options for creating image lists')
    cgroup.add_argument('--list', type=bool, default=False,
                        help='If this is set im2rec will create image list(s) by traversing root folder\
        and output to <prefix>.lst.\
        Otherwise im2rec will read <prefix>.lst and create a database at <prefix>.rec')
    cgroup.add_argument('--exts', type=list,action='append',default=['.jpeg', '.jpg'],
                        help='list of acceptable image extensions.')
    cgroup.add_argument('--chunks', type=int, default=1, help='number of chunks.')
    cgroup.add_argument('--train-ratio', type=float, default=1.0,
                        help='Ratio of images to use for training.')
    cgroup.add_argument('--test-ratio', type=float, default=0,
                        help='Ratio of images to use for testing.')
    cgroup.add_argument('--recursive', type=bool, default=False,
                        help='If true recursively walk through subdirs and assign an unique label\
        to images in each folder. Otherwise only include images in the root folder\
        and give them label 0.')

    rgroup = parser.add_argument_group('Options for creating database')
    rgroup.add_argument('--resize', type=int, default=0,
                        help='resize the shorter edge of image to the newsize, original images will\
        be packed by default.')
    rgroup.add_argument('--center-crop', type=bool, default=False,
                        help='specify whether to crop the center image to make it rectangular.')
    rgroup.add_argument('--quality', type=int, default=80,
                        help='JPEG quality for encoding, 1-100; or PNG compression for encoding, 1-9')
    rgroup.add_argument('--num-thread', type=int, default=1,
                        help='number of thread to use for encoding. order of images will be different\
        from the input list if >1. the input list will be modified to match the\
        resulting order.')
    rgroup.add_argument('--color', type=int, default=1, choices=[-1, 0, 1],
                        help='specify the color mode of the loaded image.\
        1: Loads a color image. Any transparency of image will be neglected. It is the default flag.\
        0: Loads image in grayscale mode.\
        -1:Loads image as such including alpha channel.')
    rgroup.add_argument('--encoding', type=str, default='.jpg', choices=['.jpg', '.png'],
                        help='specify the encoding of the images.')
    rgroup.add_argument('--shuffle', default=True, help='If this is set as True, \
        im2rec will randomize the image order in <prefix>.lst')
    rgroup.add_argument('--pack-label', default=False,
        help='Whether to also pack multi dimensional label in the record file') 
    args = parser.parse_args()
    args.prefix = os.path.abspath(args.prefix)
    args.root = os.path.abspath(args.root)
    return args

if __name__ == '__main__':
    args = parse_args()
    if args.list:
        make_list(args)
    else:
        if os.path.isdir(args.prefix):
            working_dir = args.prefix
        else:
            working_dir = os.path.dirname(args.prefix)
        files = [os.path.join(working_dir, fname) for fname in os.listdir(working_dir)
                    if os.path.isfile(os.path.join(working_dir, fname))]
        count = 0
        for fname in files:
            if fname.startswith(args.prefix) and fname.endswith('.lst'):
                print('Creating .rec file from', fname, 'in', working_dir)
                count += 1
                image_list = read_list(fname)
                # -- write_record -- #
                try:
                    import multiprocessing
                    q_in = [multiprocessing.Queue(1024) for i in range(args.num_thread)]
                    q_out = multiprocessing.Queue(1024)
                    read_process = [multiprocessing.Process(target=read_worker, args=(args, q_in[i], q_out)) \
                                    for i in range(args.num_thread)]
                    for p in read_process:
                        p.start()
                    write_process = multiprocessing.Process(target=write_worker, args=(q_out, fname, working_dir))
                    write_process.start()

                    for i, item in enumerate(image_list):
                        q_in[i % len(q_in)].put(item)
                    for q in q_in:
                        q.put(None)
                    for p in read_process:
                        p.join()

                    q_out.put(None)
                    write_process.join()
                except ImportError:
                    print('multiprocessing not available, fall back to single threaded encoding')
                    import Queue
                    q_out = Queue.Queue()
                    fname_rec = os.path.basename(fname)
                    fname_rec = os.path.splitext(fname)[0] + '.rec'
                    record = mx.recordio.MXRecordIO(os.path.join(working_dir, fname_rec), 'w')
                    cnt = 0
                    pre_time = time.time()
                    for item in image_list:
                        image_encode(args, item, q_out)
                        if q_out.empty():
                            continue
                        _, s, _ = q_out.get()
                        record.write(s)
                        if cnt % 1000 == 0:
                            cur_time = time.time()
                            print('time:', cur_time - pre_time, ' count:', cnt)
                            pre_time = cur_time
                        cnt += 1
        if not count:
            print('Did not find and list file with prefix %s'%args.prefix)

我们开始运行生成list文件,里面存储了图像的信息(第一个路径是生成名,第二个是图像路径)

 python im2rec.py --recursive=True  --exts=.jpg --list=True --train-ratio 0.8 --test-ratio 0.1 /home/xiaorun/mxnet/tools/mydata /home/xiaorun/mxnet/tools/images/

接下来我们利用下面脚本生成.rec格式便于Mxnet训练

 python im2rec.py mydata_train.lst images/ --resize 224 --num-thread 8

(三)接下来我们可以搭建模型训练自己数据集了
首先我们找到自己mxnet路径下 /home/xiaorun/mxnet/example/image-classification/
里面包含了很多训练脚本以及网络模型,其中symbols是目前的主流模型,小编自己稍微瞄了几眼,发现还是挺新的,也特别方便,我们用什么模型,直接调用里面的model即可,common里存放的是数据处理记忆参数的传入,大家可详细看,我以traincifar.py为例,我们只需要修改里面的数据集与模型,将数据集和模型换成自己的即可,为了方便,我们在mxnet/example/image-classification/
新建一个mynet文件用来训练自己的数据集,把traincifar.py symbols common考进去
其中traincifar.py我们要修改数据集部分,以及一些默认参数

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import os
import argparse
import logging
logging.basicConfig(level=logging.DEBUG)
from common import find_mxnet, data, fit
from common.util import download_file
import mxnet as mx

def download_cifar10():
    data_dir="/home/xiaorun/mxnet/tools/"
    fnames = (os.path.join(data_dir, "mydata_train.rec"),
              os.path.join(data_dir, "mydata_test.rec"))
    #download_file('http://data.mxnet.io/data/cifar10/cifar10_val.rec', fnames[1])
    #download_file('http://data.mxnet.io/data/cifar10/cifar10_train.rec', fnames[0])
    return fnames

if __name__ == '__main__':
    # download data
    (train_fname, val_fname) = download_cifar10()

    # parse args
    parser = argparse.ArgumentParser(description="train ",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    fit.add_fit_args(parser)
    data.add_data_args(parser)
    data.add_data_aug_args(parser)
    #data.set_data_aug_level(parser, 2)
    parser.set_defaults(
        # network,define your net
        network        = 'lenet',
        num_layers     = 110,
        # data
        data_train     = train_fname,
        data_val       = val_fname,
        #your classes
        num_classes    = 2,
        num_examples  = 737,
        image_shape    = '3,224,224',
        pad_size       = 4,
        # train
        batch_size     = 1,
        num_epochs     = 15,
        lr             = .05,
        lr_step_epochs = '5,10',
    )
    args = parser.parse_args()

    # load network
    from importlib import import_module
    net = import_module('symbols.'+args.network)
    sym = net.get_symbol(**vars(args))

    # train
    fit.fit(args, sym, data.get_rec_iter)

猜你喜欢

转载自blog.csdn.net/xiao__run/article/details/81030473