深度学习_商品检测数据集训练(7)

2.1 目标检测数据集

学习目标

  • 目标
    • 了解常用目标检测数据集
    • 了解数据集构成
  • 应用

2.1.1 常用目标检测数据集

  • pascal Visual Object Classes
扫描二维码关注公众号,回复: 12682422 查看本文章

2.1 目标检测数据集

学习目标

  • 目标
    • 了解常用目标检测数据集
    • 了解数据集构成
  • 应用

2.1.1 常用目标检测数据集

  • pascal Visual Object Classes

VOC数据集是目标检测经常用的一个数据集,从05年到12年都会举办比赛(比赛有task: Classification、Detection、Segmentation、PersonLayout),主要由VOC2007和VOC2012两个数据集

注:

官网地址:http://host.robots.ox.ac.uk/pascal/VOC/

下载地址:https://pjreddie.com/projects/pascal-voc-dataset-mirror/

  • Open Images Dataset V4

2018年发布了包含在 190 万张图片上针对 600 个类别的 1540 万个边框盒,这也是现有最大的具有对象位置注释的数据集。这些边框盒大部分都是由专业注释人员手动绘制的,确保了它们的准确性和一致性。

谷歌的数据集类目较多涵盖范围广,但是文件过多,处理起来比较麻烦,所以选择目前使用较多并且已经成熟的pascavoc数据集

2.1.2 pascal voc数据集介绍

通常使用较多的为VOC2007数据集,总共9963张图片,需要判定的总物体类别数量为20个对象类别是:

  • 人:
  • 动物:鸟,猫,牛,狗,马,羊
  • 车辆:飞机,自行车,船,公共汽车,汽车,摩托车,火车
  • 室内:瓶子,椅子,餐桌,盆栽,沙发,电视/显示器
  • 文件结构

  • 文件内容
    • Annotations: 图像中的目标标注信息xml格式
    • JPEGImages:所有图片(VOC2007中总共有9963张,训练有5011张,测试有4952张)

2.1.3 XML

以下是一个标准的物体检测标记结果格式,这就是用于训练的物体标记结果。其中有几个重点内容是后续在处理图像标记结果需要关注的。

  • size:
    • 图片尺寸大小,宽、高、通道数
  • object:
    • name:被标记物体的名称
    • bndbox:标记物体的框大小

如下例子:为000001.jpg这张图片,其中有两个物体被标记

<annotation>
    <folder>VOC2007</folder>
    <filename>000001.jpg</filename># 文件名
    <source># 文件来源
        <database>The VOC2007 Database</database>
        <annotation>PASCAL VOC2007</annotation>
        <image>flickr</image>
        <flickrid>341012865</flickrid>
    </source>
    <owner>
        <flickrid>Fried Camels</flickrid>
        <name>Jinky the Fruit Bat</name>
    </owner>
    <size># 文件尺寸,包括宽、高、通道数
        <width>353</width>
        <height>500</height>
        <depth>3</depth>
    </size>
    <segmented>0</segmented># 是否用于目标分割
    <object># 真实标记的物体
        <name>dog</name># 目标类别名称
        <pose>Left</pose>
        <truncated>1</truncated># 是否截断,这个目标因为各种原因没有被框完整(被截断了),比如说一辆车有一部分在画外面
        <difficult>0</difficult># 表明这个待检测目标很难识别,有可能是虽然视觉上很清楚,但是没有上下文的话还是很难确认它属于哪个分类,标为difficult的目标在测试评估中一般会被忽略
        <bndbox># bounding-box
            <xmin>48</xmin>
            <ymin>240</ymin>
            <xmax>195</xmax>
            <ymax>371</ymax>
        </bndbox>
    </object>
    <object># 真实标记的第二个物体
        <name>person</name>
        <pose>Left</pose>
        <truncated>1</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>8</xmin>
            <ymin>12</ymin>
            <xmax>352</xmax>
            <ymax>498</ymax>
        </bndbox>
    </object>
</annotation>

2.2 目标数据集标记

学习目标

  • 目标

    • 了解数据集标记的需求

    • 知道labelimg的标记使用

  • 应用

    • 应用labelimg完成商品数据集的标记

为什么要进行数据集标记呢?

1、提供给训练的数据样本,图片和目标真是实结果

2、特定的场景都会缺少标记图片

2.2.1 数据集标记工具介绍

2.2.1.1 介绍

LabelImg是一个图形图像注释工具。它是用Python编写的,并使用Qt作为其图形界面。注释以PASCAL VOC格式保存为XML文件,这是ImageNet使用的格式。

注:官网:https://github.com/tzutalin/labelImg

2.2.1.2 安装

官网给出了不同平台的安装教程,由于教程过于粗略。安装细节参考安装教程本地文件

参考本地文件:

2.2.2 商品数据集标记

在这里我们只是体验标记的过程,那么对于标记这个费时费力的工作,一般会有专门的数据标记团队去做,也称之为打标签,标记师。特别是缺乏具体应用场景的训练数据的时候。

2.2.2.1 需求介绍

首先在确定标记之前的需求,本项目以商品数据为例,需要明确的有

  • 1、商品图片
  • 2、需要被标记物体有哪些

我们确定了8种类别的商品(如需更细致,可将类别商品扩大),如下图

2.2.2.2 标记

使用lableimg进行商品数据集标记

  • 运行labelimg
python labelImg.py

打开如下结果

  • 对图片中的物体进行标记

标记原则为图片中所出现的物体与我们确定的8个类别物体相匹配即可

  • 按下ctrl+s键保存,软件将会保存为默认XML文件格式(XML文件名与图片文件名保持一致方便后续处理)

其中关于(xmin,ymin,xmax,ymax)我们已经解释过,可通过软件标记的时候观察是否一致

2.2.3 总结

  • 掌握labelimg的标注使用

5.1 项目训练结构介绍

学习目标

  • 目标
  • 应用

5.1.1 项目目录结构

  • ckpt:分为预训练与微调模型
  • datasets:放训练原始数据以及存储数据、读取数据代码以及模型priorbox
  • servingmodel:模型部署使用的模型位置
  • export_serving_model:导出TFserving指定模型类型
  • train_ssd:训练模型代码逻辑

5.2 标注数据读取与存储

学习目标

  • 目标
  • 应用
    • 应用XML工具进行标签数据读取以及存储

5.2.1 案例:xml读取本地文件存储到pkl

  • ElementTree工具使用,解析xml结构
  • 保存物体坐标结果以及类别
    • pickle工具导出

5.2.1.1 解析结构

  • 导入
from xml.etree import ElementTree
  • 处理XML库
    • import xml.etree.ElementTree as ET
      • tree = et.parse(filename):形成树状结构
      • tree.getroot():获取树结构的根部分
      • root.find与findall()进行查询XML每个标签的内容.text

定义解析xml结构类,

class XmlProcess(object):

    def __init__(self, data_path):
        self.path_prefix = data_path
        self.num_classes = 8
        self.data = dict()

进行preprocess_xml处理

    def preprocess_xml(self):
        # 找到文件名字
        filenames = os.listdir(self.path_prefix)
        for filename in filenames:
            # XML解析根路径
            tree = ElementTree.parse(self.path_prefix + filename)
            root = tree.getroot()
            bounding_boxes = []
            one_hot_classes = []
            size_tree = root.find('size')
            width = float(size_tree.find('width').text)
            height = float(size_tree.find('height').text)

            # 每个图片标记的对象进行坐标获取
            for object_tree in root.findall('object'):
                for bounding_box in object_tree.iter('bndbox'):
                    xmin = float(bounding_box.find('xmin').text)/width
                    ymin = float(bounding_box.find('ymin').text)/height
                    xmax = float(bounding_box.find('xmax').text)/width
                    ymax = float(bounding_box.find('ymax').text)/height
                bounding_box = [xmin, ymin, xmax, ymax]
                bounding_boxes.append(bounding_box)
                class_name = object_tree.find('name').text

                # 将类别进行one_hot编码
                one_hot_class = self.on_hot(class_name)
                one_hot_classes.append(one_hot_class)

            image_name = root.find('filename').text
            bounding_boxes = np.asarray(bounding_boxes)
            one_hot_classes = np.asarray(one_hot_classes)

            # 存储图片标注的结果对应的名字,以及图片的标注数据(4个坐标以及onehot编码)
            image_data = np.hstack((bounding_boxes, one_hot_classes))
            self.data[image_name] = image_data

one_hot编码函数

    def on_hot(self, name):
        one_hot_vector = [0] * self.num_classes
        if name == 'clothes':
            one_hot_vector[0] = 1
        elif name == 'pants':
            one_hot_vector[1] = 1
        elif name == 'shoes':
            one_hot_vector[2] = 1
        elif name == 'watch':
            one_hot_vector[3] = 1
        elif name == 'phone':
            one_hot_vector[4] = 1
        elif name == 'audio':
            one_hot_vector[5] = 1
        elif name == 'computer':
            one_hot_vector[6] = 1
        elif name == 'books':
            one_hot_vector[7] = 1
        else:
            print('unknown label: %s' % name)
        return one_hot_vector

使用preprocess进行本地保存到pickle文件

if __name__ == '__main__':
    xp = XmlProcess('/Users/huxinghui/workspace/ml/detection/ssd_detection/ssd/datasets/commodity/Annotations/')
    xp.preprocess_xml()
    pickle.dump(xp.data, open('./commodity_gt.pkl', 'wb'))

===========================================

5.3 训练

学习目标

  • 目标
  • 应用
    • 应用API完成商品数据集的训练过程

5.3.1 案例训练结果

  • 文件

5.3.2 案例思路

  • image_generator:获取图片数据标注数据生成器
    • 标注数据分割
  • 初始化模型参数以及冻结部分结构
  • compile与fit_generator

5.3.2.1 获取Generator

导入工具

from utils.detection_generate import Generator
from utils.ssd_utils import BBoxUtility
from nets.ssd_net import SSD300

import numpy as np
import pickle

定义类,进行初始化网络基础参数

class SSDTrain(object):

    def __init__(self, num_classes=9, input_shape=(300, 300, 3), epoch=50):
        self.num_classes = num_classes
        self.input_shape = input_shape
        self.epoch = epoch

        # prior box读取工具
        priors = pickle.load(open('./datasets/prior_boxes_ssd300.pkl', 'rb'))
        self.bbox_util = BBoxUtility(self.num_classes, priors)

        self.path_prefix = './datasets/commodity/JPEGImages/'

        self.model = SSD300(self.input_shape, num_classes=self.num_classes)

    def image_generator(self):

        # 获取标记数据,分成训练集与测试集
        gt = pickle.load(open('./datasets/commodity_gt.pkl', 'rb'))
        keys = sorted(gt.keys())
        num_train = int(round(0.8 * len(keys)))
        train_keys = keys[:num_train]
        val_keys = keys[num_train:]

        # Generator获取数据
        gen = Generator(gt, self.bbox_util, 16, self.path_prefix,
                        train_keys, val_keys,
                        (self.input_shape[0], self.input_shape[1]), do_crop=False)

5.3.2.3 初始化网络参数,微调网络

进行模型参数加载以及模型的结构freeze

    def init_model_param(self):
        """
        初始化模型参数
        :return:
        """

        self.model.load_weights('./ckpt/pre_trained/weights_SSD300.hdf5', by_name=True)

        # 选择freeze部分结构
        freeze = ['input_1', 'conv1_1', 'conv1_2', 'pool1',
                  'conv2_1', 'conv2_2', 'pool2',
                  'conv3_1', 'conv3_2', 'conv3_3', 'pool3']
        for L in self.model.layers:
            if L.name in freeze:
                L.trainable = False

5.3.2.4 设置训练参数以及fit

  • 使用adam默认算法

需要导入相关库,计算损失

from utils.ssd_losses import MultiboxLoss

手工

    def compile(self):
        """
        配置训练参数
        :return:
        """

        optimizer = keras.optimizers.Adam()
        self.model.compile(optimizer=optimizer,
                           loss=MultiboxLoss(self.num_classes, neg_pos_ratio=2.0).compute_loss)

    def fit_generator(self, gen):
        """
        训练
        :param gen: 图片数据生成器
        :return:
        """
        # 配置回调
        callbacks = [
            keras.callbacks.ModelCheckpoint('./ckpt/fine_tuning/weights.{epoch:02d}-{val_loss:.2f}.hdf5',
                                            verbose=1,
                                            save_weights_only=True)]
        self.model.fit_generator(gen.generate(True), gen.train_batches,
                                 self.epoch, verbose=1,
                                 callbacks=callbacks,
                                 validation_data=gen.generate(False),
                                 nb_val_samples=gen.val_batches)

5.3.3 多GPU训练代码修改

  • 在tf.keras中直接使用DistributionStrategy
    def compile(self):
        """
        配置训练参数
        :return:
        """
        distribution = tf.contrib.distribute.MirroredStrategy()

        optimizer = keras.optimizers.Adam()
        self.model.compile(optimizer=optimizer,
                           loss=MultiboxLoss(self.num_classes, neg_pos_ratio=2.0).compute_loss,
                           distribution=distribution)

5.4 本地预测测试

学习目标

  • 目标
  • 应用
    • 应用模型对本地商品图片进行预测

5.4.1 预测代码

修改源代码self.classes_name的目标个数:按照建立one_hot编码的顺序。

self.classes_name = ['clothes', 'pants', 'shoes', 'watch', 'phone',
                     'audio', 'computer', 'books']

修改读取训练过后的模型

from tensorflow import keras
from keras.applications.imagenet_utils import preprocess_input
from keras.preprocessing.image import load_img, img_to_array
import matplotlib.pyplot as plt
import numpy as np

from nets.ssd_net import SSD300
from utils.ssd_utils import BBoxUtility
from scipy.misc import imread
import os


class SSDTrain(object):

    def __init__(self):

        self.classes_name = ['Aeroplane', 'Bicycle', 'Bird', 'Boat', 'Bottle',
                               'Bus', 'Car', 'Cat', 'Chair', 'Cow', 'Diningtable',
                               'Dog', 'Horse', 'Motorbike', 'Person', 'Pottedplant',
                               'Sheep', 'Sofa', 'Train', 'Tvmonitor']

        self.classes_nums = len(self.classes_name) + 1
        self.input_shape = (300, 300, 3)

    def test(self):

        model = SSD300(self.input_shape, num_classes=self.classes_nums)

        model.load_weights('./ckpt/weights_SSD300.hdf5', by_name=True)

        # 循环读取图片进行多个图片输出检测
        feature = []
        images = []
        for pic_name in os.listdir("./image/"):
            img_path = os.path.join("./image/", pic_name)
            print(img_path)
            # 读取图片
            # 转换成数组
            # 模型输入
            img = load_img(img_path, target_size=(self.input_shape[0], self.input_shape[1]))
            img = img_to_array(img)
            feature.append(img)

            images.append(imread(img_path))
            # 处理图片数据,ndarray数组输入
            inputs = preprocess_input(np.array(feature))
        # 预测
        preds = model.predict(inputs, batch_size=1, verbose=1)
        print(preds)
        # 定义BBox工具
        bbox_util = BBoxUtility(self.classes_nums)
        # 使用非最大抑制算法过滤
        results = bbox_util.detection_out(preds)
        print(results[0].shape, results[1].shape)
        return images, results

    def tag_picture(self, images, results):
        """
        对图片预测结果画图显示
        :param images:
        :param results:
        :return:
        """

        for i, img in enumerate(images):
            # 解析输出结果,每张图片的标签,置信度和位置
            pre_label = results[i][:, 0]
            pre_conf = results[i][:, 1]
            pre_xmin = results[i][:, 2]
            pre_ymin = results[i][:, 3]
            pre_xmax = results[i][:, 4]
            pre_ymax = results[i][:, 5]
            print("label:{}, probability:{}, xmin:{}, ymin:{}, xmax:{}, ymax:{}".
                  format(pre_label, pre_conf, pre_xmin, pre_ymin, pre_xmax, pre_ymax))

            # 过滤置信度低的结果
            top_indices = [i for i, conf in enumerate(pre_conf) if conf >= 0.6]
            top_conf = pre_conf[top_indices]
            top_label_indices = pre_label[top_indices].tolist()
            top_xmin = pre_xmin[top_indices]
            top_ymin = pre_ymin[top_indices]
            top_xmax = pre_xmax[top_indices]
            top_ymax = pre_ymax[top_indices]

            # 定义21中颜色,显示图片
            # currentAxis增加图中文本显示和标记显示
            colors = plt.cm.hsv(np.linspace(0, 1, 21)).tolist()
            plt.imshow(img / 255.)
            currentAxis = plt.gca()

            for i in range(top_conf.shape[0]):
                xmin = int(round(top_xmin[i] * img.shape[1]))
                ymin = int(round(top_ymin[i] * img.shape[0]))
                xmax = int(round(top_xmax[i] * img.shape[1]))
                ymax = int(round(top_ymax[i] * img.shape[0]))

                # 获取该图片预测概率,名称,定义显示颜色
                score = top_conf[i]
                label = int(top_label_indices[i])
                label_name = self.classes_name[label - 1]
                display_txt = '{:0.2f}, {}'.format(score, label_name)
                coords = (xmin, ymin), xmax - xmin + 1, ymax - ymin + 1
                color = colors[label]
                # 显示方框
                currentAxis.add_patch(plt.Rectangle(*coords, fill=False, edgecolor=color, linewidth=2))
                # 左上角显示概率以及名称
                currentAxis.text(xmin, ymin, display_txt, bbox={'facecolor': color, 'alpha': 0.5})

            plt.show()


if __name__ == '__main__':
    ssd = SSDTrain()
    images, results = ssd.test()
    ssd.tag_picture(images, results)

猜你喜欢

转载自blog.csdn.net/qq_31784189/article/details/112723802
今日推荐