yolo源码解析(二)

五 读取数据pascal_voc.py文件解析

我们在YOLENet类中定义了两个占位符,一个是输入图片占位符,一个是图片对应的标签占位符,如下:

复制代码
        #输入图片占位符 [NONE,image_size,image_size,3]
        self.images = tf.placeholder(
            tf.float32, [None, self.image_size, self.image_size, 3],
            name='images')
        #设置标签占位符 [None,S,S,5+C]  即[None,7,7,25]
        self.labels = tf.placeholder(
            tf.float32,
            [None, self.cell_size, self.cell_size, 5 + self.num_class])        
复制代码

而pascal_voc.py文件的目的就是为了准备数据,赋值给占位符。在pascal_voc.py文件中定义了一个pascal_voc,该类包含了类初始化函数(__init__()),准备数据函数(prepare()),读取batch大小的图片以及图片对应的标签(get())等函数。

复制代码
import os
import xml.etree.ElementTree as ET
import numpy as np
import cv2
import pickle
import copy
import yolo.config as cfg


'''
VOC2012数据集处理
'''

class pascal_voc(object):
复制代码

1、类初始化函数

复制代码
'''
VOC2012数据集处理
'''

class pascal_voc(object):
    '''
    VOC2012数据集处理的类,主要用来获取训练集图片文件,以及生成对应的标签文件
    '''
    def __init__(self, phase, rebuild=False):
        '''
        准备训练或者测试的数据
        
        args:
            phase:传入字符串 'train':表示训练
                              'test':测试
            rebuild:是否重新创建数据集的标签文件,保存在缓存文件夹下
        '''
        #VOCdevkit文件夹路径
        self.devkil_path = os.path.join(cfg.PASCAL_PATH, 'VOCdevkit')
        #VOC2012文件夹路径
        self.data_path = os.path.join(self.devkil_path, 'VOC2012')
        #catch文件所在路径
        self.cache_path = cfg.CACHE_PATH
        #批大小
        self.batch_size = cfg.BATCH_SIZE
        #图像大小
        self.image_size = cfg.IMAGE_SIZE
        #单元格大小S
        self.cell_size = cfg.CELL_SIZE
        #VOC 2012数据集类别名
        self.classes = cfg.CLASSES
        #类别名->索引的dict
        self.class_to_ind = dict(zip(self.classes, range(len(self.classes))))
        ##图片是否采用水平镜像扩充训练集?
        self.flipped = cfg.FLIPPED
        #训练或测试?
        self.phase = phase
        #是否重新创建数据集标签文件
        self.rebuild = rebuild
        #从gt_labels加载数据,cursor表明当前读取到第几个
        self.cursor = 0
        #存放当前训练的轮数
        self.epoch = 1
        #存放数据集的标签 是一个list 每一个元素都是一个dict,对应一个图片 
        #如果我们在配置文件中指定flipped=True,则数据集会扩充一倍,每一张原始图片都有一个水平对称的镜像文件
        #      imname:图片路径 
        #      label:图片标签
        #      flipped:图片水平镜像?
        self.gt_labels = None
        #加载数据集标签  初始化gt_labels
        self.prepare()
复制代码

2、prepare()所有数据准备函数

prepare()函数调用load_labels()函数,加载所有数据集的标签,保存在遍历gt_labels集合中,如果在配置文件中指定了水平镜像,则追加一倍的训练数据集。

复制代码
    def prepare(self):
        '''
        初始化数据集的标签,保存在变量gt_labels中
        
        return:
            gt_labels:返回数据集的标签 是一个list  每一个元素对应一张图片,是一个dict                       
                                     imname:图片文件路径
                                     label:图片文件对应的标签 [7,7,25]的矩阵
                                     flipped:是否使用水平镜像? 设置为False
        '''
        #加载数据集的标签
        gt_labels = self.load_labels()
        #如果水平镜像,则追加一倍的训练数据集
        if self.flipped:
            print('Appending horizontally-flipped training examples ...')
            #深度拷贝
            gt_labels_cp = copy.deepcopy(gt_labels)
            #遍历每一个图片标签
            for idx in range(len(gt_labels_cp)):
                #设置flipped属性为True
                gt_labels_cp[idx]['flipped'] = True
                #目标所在格子也进行水平镜像 [7,7,25]
                gt_labels_cp[idx]['label'] =\
                    gt_labels_cp[idx]['label'][:, ::-1, :]
                for i in range(self.cell_size):
                    for j in range(self.cell_size):
                        #置信度==1,表示这个格子有目标
                        if gt_labels_cp[idx]['label'][i, j, 0] == 1:
                            #中心的x坐标水平镜像
                            gt_labels_cp[idx]['label'][i, j, 1] = \
                                self.image_size - 1 -\
                                gt_labels_cp[idx]['label'][i, j, 1]
            #追加数据集的标签   后面的是由原数据集标签扩充的水平镜像数据集标签
            gt_labels += gt_labels_cp
        #打乱数据集的标签
        np.random.shuffle(gt_labels)
        self.gt_labels = gt_labels
        return gt_labels
复制代码

3、get()批量数据读取函数

get()函数用在训练的时候,每次从gt_labels集合随机读取batch大小的图片以及图片对应的标签。

复制代码
    def get(self):
        '''
        加载数据集 每次读取batch大小的图片以及图片对应的标签
        
        return:
            images:读取到的图片数据 [45,448,448,3]
            labels:对应的图片标签 [45,7,7,25]
        '''
        #[45,448,448,3]
        images = np.zeros(
            (self.batch_size, self.image_size, self.image_size, 3))
        #[45,7,7,25]
        labels = np.zeros(
            (self.batch_size, self.cell_size, self.cell_size, 25))
        count = 0
        #一次加载batch_size个图片数据
        while count < self.batch_size:
            #获取图片路径
            imname = self.gt_labels[self.cursor]['imname']
            #是否使用水平镜像?
            flipped = self.gt_labels[self.cursor]['flipped']
            #读取图片数据
            images[count, :, :, :] = self.image_read(imname, flipped)
            #读取图片标签
            labels[count, :, :, :] = self.gt_labels[self.cursor]['label']
            count += 1
            self.cursor += 1
            #如果读取完一轮数据,则当前cursor置为0,当前训练轮数+1
            if self.cursor >= len(self.gt_labels):
                #打乱数据集
                np.random.shuffle(self.gt_labels)
                self.cursor = 0                
                self.epoch += 1
        return images, labels
复制代码

4、image_read()函数读取图片

图片读取函数,先读取图片,然后缩放,转换为RGB格式,再对数据进行归一化处理。

复制代码
    def image_read(self, imname, flipped=False):
        '''
        读取图片
        
        args:
            imname:图片路径
            flipped:图片是否水平镜像处理? 
            
        return:
            image:图片数据 [448,448,3]
        '''
        #读取图片数据
        image = cv2.imread(imname)
        #缩放处理
        image = cv2.resize(image, (self.image_size, self.image_size))
        #BGR->RGB  uint->float32
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        #归一化处理 [-1.0,1.0]
        image = (image / 255.0) * 2.0 - 1.0
        #宽倒序  即水平镜像
        if flipped:
            image = image[:, ::-1, :]
        return image
复制代码

5、load_labels()加载标签函数

复制代码
    def load_labels(self):
        '''
        加载数据集标签
        
        return:
            gt_labels:是一个list  每一个元素对应一张图片,是一个dict                       
                                     imname:图片文件路径
                                     label:图片文件对应的标签 [7,7,25]的矩阵
                                     flipped:是否使用水平镜像? 设置为False   
        '''
        #缓冲文件名:即用来保存数据集标签的文件
        cache_file = os.path.join(
            self.cache_path, 'pascal_' + self.phase + '_gt_labels.pkl')

        #文件存在,且不重新创建则直接读取
        if os.path.isfile(cache_file) and not self.rebuild:
            print('Loading gt_labels from: ' + cache_file)
            with open(cache_file, 'rb') as f:
                gt_labels = pickle.load(f)
            return gt_labels

        print('Processing gt_labels from: ' + self.data_path)

        #如果缓冲文件目录不存在,创建
        if not os.path.exists(self.cache_path):
            os.makedirs(self.cache_path)
            
        #获取训练测试集的数据文件名
        if self.phase == 'train':
            txtname = os.path.join(
                self.data_path, 'ImageSets', 'Main', 'trainval.txt')
        #获取测试集的数据文件名
        else:
            txtname = os.path.join(
                self.data_path, 'ImageSets', 'Main', 'test.txt')
        with open(txtname, 'r') as f:
            self.image_index = [x.strip() for x in f.readlines()]

        #存放图片的标签,图片路径,是否使用水平镜像?
        gt_labels = []
        #遍历每一张图片的信息
        for index in self.image_index:
            #读取每一张图片的标签label [7,7,25]
            label, num = self.load_pascal_annotation(index)
            if num == 0:
                continue
            #图片文件路径
            imname = os.path.join(self.data_path, 'JPEGImages', index + '.jpg')
            #保存该图片的信息
            gt_labels.append({'imname': imname,
                              'label': label,
                              'flipped': False})
        print('Saving gt_labels to: ' + cache_file)
        #保存
        with open(cache_file, 'wb') as f:
            pickle.dump(gt_labels, f)
        return gt_labels
复制代码

6、load_pascal_annotation()函数

复制代码
    def load_pascal_annotation(self, index):
        """
        Load image and bounding boxes info from XML file in the PASCAL VOC
        format.
        
        args:
            index:图片文件的index
            
        return :
            label:标签 [7,7,25] 
                      0:1:置信度,表示这个地方是否有目标
                      1:5:目标边界框  目标中心,宽度和高度(这里是实际值,没有归一化)
                      5:25:目标的类别
            len(objs):objs对象长度
        """
        #获取图片文件名路径
        imname = os.path.join(self.data_path, 'JPEGImages', index + '.jpg')
        #读取数据
        im = cv2.imread(imname)
        #宽和高缩放比例
        h_ratio = 1.0 * self.image_size / im.shape[0]
        w_ratio = 1.0 * self.image_size / im.shape[1]
        # im = cv2.resize(im, [self.image_size, self.image_size])
        #用于保存图片文件的标签
        label = np.zeros((self.cell_size, self.cell_size, 25))
        #图片文件的标注xml文件
        filename = os.path.join(self.data_path, 'Annotations', index + '.xml')
        tree = ET.parse(filename)
        objs = tree.findall('object')

        for obj in objs:
            bbox = obj.find('bndbox')
            # Make pixel indexes 0-based  当图片缩放到image_size时,边界框也进行同比例缩放
            x1 = max(min((float(bbox.find('xmin').text) - 1) * w_ratio, self.image_size - 1), 0)
            y1 = max(min((float(bbox.find('ymin').text) - 1) * h_ratio, self.image_size - 1), 0)
            x2 = max(min((float(bbox.find('xmax').text) - 1) * w_ratio, self.image_size - 1), 0)
            y2 = max(min((float(bbox.find('ymax').text) - 1) * h_ratio, self.image_size - 1), 0)
            #根据图片的分类名 ->类别index 转换
            cls_ind = self.class_to_ind[obj.find('name').text.lower().strip()]
            #计算边框中心点x,y,w,h(没有归一化)
            boxes = [(x2 + x1) / 2.0, (y2 + y1) / 2.0, x2 - x1, y2 - y1]
            #计算当前物体的中心在哪个格子中
            x_ind = int(boxes[0] * self.cell_size / self.image_size)
            y_ind = int(boxes[1] * self.cell_size / self.image_size)
            #表明该图片已经初始化过了
            if label[y_ind, x_ind, 0] == 1:
                continue
            #置信度,表示这个地方有物体
            label[y_ind, x_ind, 0] = 1
            #物体边界框
            label[y_ind, x_ind, 1:5] = boxes
            #物体的类别
            label[y_ind, x_ind, 5 + cls_ind] = 1

        return label, len(objs)
复制代码

六 训练网络

模型训练包含于train.py文件,Solver类的train()方法之中,训练部分只需要看懂了初始化参数,整个结构就很清晰了。

复制代码
import os
import argparse
import datetime
import tensorflow as tf
import yolo.config as cfg
from yolo.yolo_net import YOLONet
from utils.timer import Timer
from utils.pascal_voc import pascal_voc

slim = tf.contrib.slim

'''
用来训练YOLO网络模型
'''

class Solver(object):
    '''
    求解器的类,用于训练YOLO网络
    '''
复制代码

1、类初始化函数

复制代码
   def __init__(self, net, data):
        '''
        构造函数,加载训练参数
        
        args:
            net:YOLONet对象
            data:pascal_voc对象
        '''
        #yolo网络
        self.net = net
        #voc2012数据处理
        self.data = data
        #检查点文件路径
        self.weights_file = cfg.WEIGHTS_FILE
        #训练最大迭代次数
        self.max_iter = cfg.MAX_ITER
        #初始学习率
        self.initial_learning_rate = cfg.LEARNING_RATE
        ##退化学习率衰减步数
        self.decay_steps = cfg.DECAY_STEPS
        #衰减率
        self.decay_rate = cfg.DECAY_RATE
        self.staircase = cfg.STAIRCASE
        ##日志文件保存间隔步
        self.summary_iter = cfg.SUMMARY_ITER
        ##模型保存间隔步
        self.save_iter = cfg.SAVE_ITER
        
        #输出文件夹路径
        self.output_dir = os.path.join(
            cfg.OUTPUT_DIR, datetime.datetime.now().strftime('%Y_%m_%d_%H_%M'))
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)
        #保存配置信息
        self.save_cfg()
        #指定保存的张量 这里指定所有变量
        self.variable_to_restore = tf.global_variables()        
        self.saver = tf.train.Saver(self.variable_to_restore, max_to_keep=None)
        #指定保存的模型名称
        self.ckpt_file = os.path.join(self.output_dir, 'yolo.cpkt')
        #合并所有的summary
        self.summary_op = tf.summary.merge_all()
        #创建writer,指定日志文件路径,用于写日志文件
        self.writer = tf.summary.FileWriter(self.output_dir, flush_secs=60)

        #创建变量,保存当前迭代次数
        self.global_step = tf.train.create_global_step()
        #退化学习率
        self.learning_rate = tf.train.exponential_decay(
            self.initial_learning_rate, self.global_step, self.decay_steps,
            self.decay_rate, self.staircase, name='learning_rate')
        #创建求解器
        self.optimizer = tf.train.GradientDescentOptimizer(
            learning_rate=self.learning_rate)
        # create_train_op that ensures that when we evaluate it to get the loss,
        # the update_ops are done and the gradient updates are computed.
        self.train_op = slim.learning.create_train_op(
            self.net.total_loss, self.optimizer, global_step=self.global_step)

        #设置GPU使用资源
        gpu_options = tf.GPUOptions()
        #按需分配GPU使用的资源
        config = tf.ConfigProto(gpu_options=gpu_options)
        self.sess = tf.Session(config=config)
        
        #运行图之前,初始化变量
        self.sess.run(tf.global_variables_initializer())

        #恢复模型
        if self.weights_file is not None:
            print('Restoring weights from: ' + self.weights_file)
            self.saver.restore(self.sess, self.weights_file)

        #将图写入日志文件
        self.writer.add_graph(self.sess.graph)
复制代码

 2、train()训练函数

复制代码
 def train(self):
        '''
        开始训练
        '''
        #训练时间
        train_timer = Timer()
        #数据集加载时间
        load_timer = Timer()

        #开始迭代
        for step in range(1, self.max_iter + 1):
            #计算每次迭代加载数据的起始时间
            load_timer.tic()
            #加载数据集 每次读取batch大小的图片以及图片对应的标签
            images, labels = self.data.get()
            #计算这次迭代加载数据集所使用的时间
            load_timer.toc()
            
            feed_dict = {self.net.images: images,
                         self.net.labels: labels}

            #迭代summary_iter次,保存一次日志文件,迭代summary_iter*10次,输出一次的迭代信息
            if step % self.summary_iter == 0:
                if step % (self.summary_iter * 10) == 0:
                    #计算每次迭代训练的起始时间
                    train_timer.tic()
                    loss = 0.0001  
                    #开始迭代训练,每一次迭代后global_step自加1
                    summary_str, loss, _ = self.sess.run(
                        [self.summary_op, self.net.total_loss, self.train_op],
                        feed_dict=feed_dict)
                    #输出信息
                    log_str = '{} Epoch: {}, Step: {}, Learning rate: {}, Loss: {:5.3f}\nSpeed: {:.3f}s/iter,Load: {:.3f}s/iter, Remain: {}'.format(
                        datetime.datetime.now().strftime('%m-%d %H:%M:%S'),
                        self.data.epoch,
                        int(step),
                        round(self.learning_rate.eval(session=self.sess), 6),
                        loss,
                        train_timer.average_time,
                        load_timer.average_time,
                        train_timer.remain(step, self.max_iter))
                    print(log_str)

                else:
                    #计算每次迭代训练的起始时间
                    train_timer.tic()           
                    #开始迭代训练,每一次迭代后global_step自加1
                    summary_str, _ = self.sess.run(
                        [self.summary_op, self.train_op],
                        feed_dict=feed_dict)  
                    #计算这次迭代训练所使用的时间
                    train_timer.toc()
                    
                #将summary写入文件
                self.writer.add_summary(summary_str, step)

            else:
                #计算每次迭代训练的起始时间
                train_timer.tic()
                #开始迭代训练,每一次迭代后global_step自加1
                self.sess.run(self.train_op, feed_dict=feed_dict)
                #计算这次迭代训练所使用的时间
                train_timer.toc()

            #没迭代save_iter次,保存一次模型
            if step % self.save_iter == 0:
                print('{} Saving checkpoint file to: {}'.format(
                    datetime.datetime.now().strftime('%m-%d %H:%M:%S'),
                    self.output_dir))
                self.saver.save(
                    self.sess, self.ckpt_file, global_step=self.global_step)
复制代码

3、保存配置参数

复制代码
    def save_cfg(self):
        '''
        保存配置信息
        '''
        with open(os.path.join(self.output_dir, 'config.txt'), 'w') as f:
            cfg_dict = cfg.__dict__
            for key in sorted(cfg_dict.keys()):
                if key[0].isupper():
                    cfg_str = '{}: {}\n'.format(key, cfg_dict[key])
                    f.write(cfg_str)
复制代码

train.py文件除了上面介绍的求解器Solver这个类外,还包含了两个函数,一个是update_config_paths()函数,这个函数主要使用了设定数据集路径,以及检查点文件路径。

复制代码
def update_config_paths(data_dir, weights_file):
    '''
    数据集路径,和模型检查点文件路径
    
    args:
        data_dir:数据文件夹  数据集放在pascal_voc目录下  
        weights_file:检查点文件名 该文件放在数据集目录下的weights文件夹下
        
    '''
    cfg.DATA_PATH = data_dir                                                   #数据所在文件夹
    cfg.PASCAL_PATH = os.path.join(data_dir, 'pascal_voc')                     #VOC2012数据所在文件夹
    cfg.CACHE_PATH = os.path.join(cfg.PASCAL_PATH, 'cache')                    #保存生成的数据集标签缓冲文件所在文件夹
    cfg.OUTPUT_DIR = os.path.join(cfg.PASCAL_PATH, 'output')                   #保存生成的网络模型和日志文件所在的文件夹
    cfg.WEIGHTS_DIR = os.path.join(cfg.PASCAL_PATH, 'weights')                 #检查点文件所在的目录

    cfg.WEIGHTS_FILE = os.path.join(cfg.WEIGHTS_DIR, weights_file)
复制代码

我们主要来说一下另一个函数main()函数,先解析命令行参数,然后创建YOLONet、pascal_voc、Solver对象,最后开始训练。

复制代码
def main():
    #创建一个解析器对象,并告诉它将会有些什么参数。当程序运行时,该解析器就可以用于处理命令行参数。
    #https://www.cnblogs.com/lovemyspring/p/3214598.html
    parser = argparse.ArgumentParser()
    #定义参数
    parser.add_argument('--weights', default="YOLO_small.ckpt", type=str)    #权重文件名
    parser.add_argument('--data_dir', default="data", type=str)              #数据集路径
    parser.add_argument('--threshold', default=0.2, type=float)
    parser.add_argument('--iou_threshold', default=0.5, type=float)
    parser.add_argument('--gpu', default='', type=str)
    #定义了所有参数之后,你就可以给 parse_args() 传递一组参数字符串来解析命令行。默认情况下,参数是从 sys.argv[1:] 中获取
    #parse_args() 的返回值是一个命名空间,包含传递给命令的参数。该对象将参数保存其属性
    args = parser.parse_args()

    #判断是否是使用gpu
    if args.gpu is not None:
        cfg.GPU = args.gpu

    #设定数据集路径,以及检查点文件路径
    if args.data_dir != cfg.DATA_PATH  and args.data_dir is not None:
        update_config_paths(args.data_dir, args.weights)

    #设置环境变量
    os.environ['CUDA_VISIBLE_DEVICES'] = cfg.GPU

    #创建YOLO网络对象
    yolo = YOLONet()
    #数据集对象
    pascal = pascal_voc('train')
    #求解器对象
    solver = Solver(yolo, pascal)

    print('Start training ...')
    #开始训练
    solver.train()
    print('Done training.')
复制代码

我们执行如下代码,开始训练网络:

if __name__ == '__main__':
    tf.reset_default_graph()
    # python train.py --weights YOLO_small.ckpt --gpu 0
    main()

猜你喜欢

转载自www.cnblogs.com/sddai/p/10288096.html