TensorFlow数据读取机制:文件队列 tf.train.slice_input_producer和 tf.data.Dataset机制

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接: https://blog.csdn.net/guyuealian/article/details/85106012

TensorFlow数据读取机制:文件队列 tf.train.slice_input_producer和tf.data.Dataset机制


    之前写了一篇博客,关于《Tensorflow生成自己的图片数据集TFrecord》,项目做多了,你会发现将数据转为TFrecord格式,实在是太麻烦了,灵活性太差!后面就总结一下TensorFlow数据读取机制,主要还是介绍tf.data.Dataset的数据读取机制(Pipeline机制)。

    TensorFlow数据读取机制主要是两种方法:

(1)一种是使用文件队列方式,如使用slice_input_producer和string_input_producer;这种方法既可以将数据转存为TFrecord数据格式,也可以直接读取文件图片数据,当然转存为TFrecord数据格式进行读取,会更高效点

(2)另一种是TensorFlow 1.4版本后出现的tf.data.Dataset的数据读取机制(Pipeline机制)。这是TensorFlow强烈推荐的方式,是一种更高效的读取方式。使用tf.data.Dataset模块的pipline机制,可实现CPU多线程处理输入的数据,如读取图片和图片的一些的预处理,这样GPU可以专注于训练过程,而CPU去准备数据。

      本博客Github源码:https://github.com/PanJinquan/tensorflow-learning-tutorials ->tf_record_demo文件夹(觉得可以,还请给个“Star”哦

     之前专门写了一篇博客关于《 Tensorflow生成自己的图片数据集TFrecords(支持多标签label)https://blog.csdn.net/guyuealian/article/details/80857228,主要实现的是使用自己的数据集制作TensorFlow的TFrecord数据格式。


目录

目录

TensorFlow数据读取机制:文件队列 tf.train.slice_input_producer和tf.data.Dataset机制

1. 文件队列读取方式:slice_input_producer和string_input_producer

1.1.生成图片数据集TFrecords

(1)生成单个record文件 (单label)

(2)生成单个record文件 (多label)

(3)生成多个record文件的方法

1.2. 直接文件读取方式 

2.tf.data.Dataset数据读取机制:Pipeline机制

prefetch(必须放在最后)

map

repeat

实例代码1:dataset.make_initializable_iterator()

实例代码2:dataset.make_one_shot_iterator()

实例代码3:产生用于训练的图和label

实例代码4:产生用于训练的原始图和target目标图

实例代码5: tf.data.Dataset.from_generator

3. 用Python循环产生批量数据batch

4.参考资料:



1. 文件队列读取方式:slice_input_producer和string_input_producer

    TensorFlow可以采用tf.train.slice_input_producer或者tf.train.string_input_producer两种方法产生文件队列,其区别就是:前者是输入是tensor_list,因此,可以将多个list组合成一个tensorlist作为输入;而后者只能是一个string_tensor了,例子如下:

    image_dir ='path/to/image_dir/*.jpg'
    image_list = glob.glob(image_dir)
    label_list=...
    image_list = tf.convert_to_tensor(image_list, dtype=tf.string)
    # 可以将image_list,label_list多个list组合成一个tensor_list
    image_que, label_que = tf.train.slice_input_producer([image_list,label_list], num_epochs=1)
    # 只能时string_tensor,所以不能组合多个list
    image = tf.train.string_input_producer(image_list, num_epochs=1)

1.1.生成图片数据集TFrecords

    假设train.txt保存图片的路径和标签信息,如下,以空格分割,第一项的图片的路径名,第二项是图片对应的labels

dog/1.jpg 0
dog/2.jpg 0
dog/3.jpg 0
dog/4.jpg 0
cat/1.jpg 1
cat/2.jpg 1
cat/3.jpg 1
cat/4.jpg 1

    这里提供三种方法将图像数据转存为TFrecords数据格式,当然也包含TFrecords解析的方法,详细的用法都会在函数参数说明,已经封装了很简单了,你只需要改变你图片的路径就可以。

  • 生成单个record文件 (单label)

    这种方法会将所有图片数据和单labels转存为一个record文件,合适单labels小批量的数据

  • 生成单个record文件 (多label)

    这种方法将所有图片数据和多个labels转存为一个record文件,合适多labels的小批量的数据

  • 生成多个record文件的方法

    这种方法将图片数据和labels,切分一个batch_size的大小,并转存为多个record文件,合适大批量的数据

(1)生成单个record文件 (单label)

     下面是封装好的py文件,可以直接生成单个record文件 ,当然这里假设只有一个label情况。其中get_batch_images函数会产生一个batch的数据,这个batch的数据就可以用于CNN的网络训练的数据。

# -*-coding: utf-8 -*-
"""
    @Project: create_tfrecord
    @File   : create_tfrecord.py
    @Author : panjq
    @E-mail : [email protected]
    @Date   : 2018-07-27 17:19:54
    @desc   : 将图片数据保存为单个tfrecord文件
"""

##########################################################################

import tensorflow as tf
import numpy as np
import os
import cv2
import matplotlib.pyplot as plt
import random
from PIL import Image


##########################################################################
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
# 生成字符串型的属性
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
# 生成实数型的属性
def float_list_feature(value):
  return tf.train.Feature(float_list=tf.train.FloatList(value=value))

def get_example_nums(tf_records_filenames):
    '''
    统计tf_records图像的个数(example)个数
    :param tf_records_filenames: tf_records文件路径
    :return:
    '''
    nums= 0
    for record in tf.python_io.tf_record_iterator(tf_records_filenames):
        nums += 1
    return nums

def show_image(title,image):
    '''
    显示图片
    :param title: 图像标题
    :param image: 图像的数据
    :return:
    '''
    # plt.figure("show_image")
    # print(image.dtype)
    plt.imshow(image)
    plt.axis('on')    # 关掉坐标轴为 off
    plt.title(title)  # 图像题目
    plt.show()

def load_labels_file(filename,labels_num=1,shuffle=False):
    '''
    载图txt文件,文件中每行为一个图片信息,且以空格隔开:图像路径 标签1 标签2,如:test_image/1.jpg 0 2
    :param filename:
    :param labels_num :labels个数
    :param shuffle :是否打乱顺序
    :return:images type->list
    :return:labels type->list
    '''
    images=[]
    labels=[]
    with open(filename) as f:
        lines_list=f.readlines()
        if shuffle:
            random.shuffle(lines_list)

        for lines in lines_list:
            line=lines.rstrip().split(' ')
            label=[]
            for i in range(labels_num):
                label.append(int(line[i+1]))
            images.append(line[0])
            labels.append(label)
    return images,labels

def read_image(filename, resize_height, resize_width,normalization=False):
    '''
    读取图片数据,默认返回的是uint8,[0,255]
    :param filename:
    :param resize_height:
    :param resize_width:
    :param normalization:是否归一化到[0.,1.0]
    :return: 返回的图片数据
    '''

    bgr_image = cv2.imread(filename)
    if len(bgr_image.shape)==2:#若是灰度图则转为三通道
        print("Warning:gray image",filename)
        bgr_image = cv2.cvtColor(bgr_image, cv2.COLOR_GRAY2BGR)

    rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)#将BGR转为RGB
    # show_image(filename,rgb_image)
    # rgb_image=Image.open(filename)
    if resize_height>0 and resize_width>0:
        rgb_image=cv2.resize(rgb_image,(resize_width,resize_height))
    rgb_image=np.asanyarray(rgb_image)
    if normalization:
        # 不能写成:rgb_image=rgb_image/255
        rgb_image=rgb_image/255.0
    # show_image("src resize image",image)
    return rgb_image


def get_batch_images(images,labels,batch_size,labels_nums,one_hot=False,shuffle=False,num_threads=1):
    '''
    :param images:图像
    :param labels:标签
    :param batch_size:
    :param labels_nums:标签个数
    :param one_hot:是否将labels转为one_hot的形式
    :param shuffle:是否打乱顺序,一般train时shuffle=True,验证时shuffle=False
    :return:返回batch的images和labels
    '''
    min_after_dequeue = 200
    capacity = min_after_dequeue + 3 * batch_size  # 保证capacity必须大于min_after_dequeue参数值
    if shuffle:
        images_batch, labels_batch = tf.train.shuffle_batch([images,labels],
                                                                    batch_size=batch_size,
                                                                    capacity=capacity,
                                                                    min_after_dequeue=min_after_dequeue,
                                                                    num_threads=num_threads)
    else:
        images_batch, labels_batch = tf.train.batch([images,labels],
                                                        batch_size=batch_size,
                                                        capacity=capacity,
                                                        num_threads=num_threads)
    if one_hot:
        labels_batch = tf.one_hot(labels_batch, labels_nums, 1, 0)
    return images_batch,labels_batch

def read_records(filename,resize_height, resize_width,type=None):
    '''
    解析record文件:源文件的图像数据是RGB,uint8,[0,255],一般作为训练数据时,需要归一化到[0,1]
    :param filename:
    :param resize_height:
    :param resize_width:
    :param type:选择图像数据的返回类型
         None:默认将uint8-[0,255]转为float32-[0,255]
         normalization:归一化float32-[0,1]
         standardization:标准化float32-[0,1],再减均值中心化
    :return:
    '''
    # 创建文件队列,不限读取的数量
    filename_queue = tf.train.string_input_producer([filename])
    # create a reader from file queue
    reader = tf.TFRecordReader()
    # reader从文件队列中读入一个序列化的样本
    _, serialized_example = reader.read(filename_queue)
    # get feature from serialized example
    # 解析符号化的样本
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image_raw': tf.FixedLenFeature([], tf.string),
            'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'depth': tf.FixedLenFeature([], tf.int64),
            'label': tf.FixedLenFeature([], tf.int64)
        }
    )
    tf_image = tf.decode_raw(features['image_raw'], tf.uint8)#获得图像原始的数据

    tf_height = features['height']
    tf_width = features['width']
    tf_depth = features['depth']
    tf_label = tf.cast(features['label'], tf.int32)
    # PS:恢复原始图像数据,reshape的大小必须与保存之前的图像shape一致,否则出错
    # tf_image=tf.reshape(tf_image, [-1])    # 转换为行向量
    tf_image=tf.reshape(tf_image, [resize_height, resize_width, 3]) # 设置图像的维度

    # 恢复数据后,才可以对图像进行resize_images:输入uint->输出float32
    # tf_image=tf.image.resize_images(tf_image,[224, 224])

    # [3]数据类型处理
    # 存储的图像类型为uint8,tensorflow训练时数据必须是tf.float32
    if type is None:
        tf_image = tf.cast(tf_image, tf.float32)
    elif type == 'normalization':  # [1]若需要归一化请使用:
        # 仅当输入数据是uint8,才会归一化[0,255]
        # tf_image = tf.cast(tf_image, dtype=tf.uint8)
        # tf_image = tf.image.convert_image_dtype(tf_image, tf.float32)
        tf_image = tf.cast(tf_image, tf.float32) * (1. / 255.0)  # 归一化
    elif type == 'standardization':  # 标准化
        # tf_image = tf.cast(tf_image, dtype=tf.uint8)
        # tf_image = tf.image.per_image_standardization(tf_image)  # 标准化(减均值除方差)
        # 若需要归一化,且中心化,假设均值为0.5,请使用:
        tf_image = tf.cast(tf_image, tf.float32) * (1. / 255) - 0.5  # 中心化

    # 这里仅仅返回图像和标签
    # return tf_image, tf_height,tf_width,tf_depth,tf_label
    return tf_image,tf_label


def create_records(image_dir,file, output_record_dir, resize_height, resize_width,shuffle,log=5):
    '''
    实现将图像原始数据,label,长,宽等信息保存为record文件
    注意:读取的图像数据默认是uint8,再转为tf的字符串型BytesList保存,解析请需要根据需要转换类型
    :param image_dir:原始图像的目录
    :param file:输入保存图片信息的txt文件(image_dir+file构成图片的路径)
    :param output_record_dir:保存record文件的路径
    :param resize_height:
    :param resize_width:
    PS:当resize_height或者resize_width=0是,不执行resize
    :param shuffle:是否打乱顺序
    :param log:log信息打印间隔
    '''
    # 加载文件,仅获取一个label
    images_list, labels_list=load_labels_file(file,1,shuffle)

    writer = tf.python_io.TFRecordWriter(output_record_dir)
    for i, [image_name, labels] in enumerate(zip(images_list, labels_list)):
        image_path=os.path.join(image_dir,images_list[i])
        if not os.path.exists(image_path):
            print('Err:no image',image_path)
            continue
        image = read_image(image_path, resize_height, resize_width)
        image_raw = image.tostring()
        if i%log==0 or i==len(images_list)-1:
            print('------------processing:%d-th------------' % (i))
            print('current image_path=%s' % (image_path),'shape:{}'.format(image.shape),'labels:{}'.format(labels))
        # 这里仅保存一个label,多label适当增加"'label': _int64_feature(label)"项
        label=labels[0]
        example = tf.train.Example(features=tf.train.Features(feature={
            'image_raw': _bytes_feature(image_raw),
            'height': _int64_feature(image.shape[0]),
            'width': _int64_feature(image.shape[1]),
            'depth': _int64_feature(image.shape[2]),
            'label': _int64_feature(label)
        }))
        writer.write(example.SerializeToString())
    writer.close()

def disp_records(record_file,resize_height, resize_width,show_nums=4):
    '''
    解析record文件,并显示show_nums张图片,主要用于验证生成record文件是否成功
    :param tfrecord_file: record文件路径
    :return:
    '''
    # 读取record函数
    tf_image, tf_label = read_records(record_file,resize_height,resize_width,type='normalization')
    # 显示前4个图片
    init_op = tf.initialize_all_variables()
    with tf.Session() as sess:
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        for i in range(show_nums):
            image,label = sess.run([tf_image,tf_label])  # 在会话中取出image和label
            # image = tf_image.eval()
            # 直接从record解析的image是一个向量,需要reshape显示
            # image = image.reshape([height,width,depth])
            print('shape:{},tpye:{},labels:{}'.format(image.shape,image.dtype,label))
            # pilimg = Image.fromarray(np.asarray(image_eval_reshape))
            # pilimg.show()
            show_image("image:%d"%(label),image)
        coord.request_stop()
        coord.join(threads)


def batch_test(record_file,resize_height, resize_width):
    '''
    :param record_file: record文件路径
    :param resize_height:
    :param resize_width:
    :return:
    :PS:image_batch, label_batch一般作为网络的输入
    '''
    # 读取record函数
    tf_image,tf_label = read_records(record_file,resize_height,resize_width,type='normalization')
    image_batch, label_batch= get_batch_images(tf_image,tf_label,batch_size=4,labels_nums=5,one_hot=False,shuffle=False)

    init = tf.global_variables_initializer()
    with tf.Session() as sess:  # 开始一个会话
        sess.run(init)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        for i in range(4):
            # 在会话中取出images和labels
            images, labels = sess.run([image_batch, label_batch])
            # 这里仅显示每个batch里第一张图片
            show_image("image", images[0, :, :, :])
            print('shape:{},tpye:{},labels:{}'.format(images.shape,images.dtype,labels))

        # 停止所有线程
        coord.request_stop()
        coord.join(threads)


if __name__ == '__main__':
    # 参数设置

    resize_height = 224  # 指定存储图片高度
    resize_width = 224  # 指定存储图片宽度
    shuffle=True
    log=5
    # 产生train.record文件
    image_dir='dataset/train'
    train_labels = 'dataset/train.txt'  # 图片路径
    train_record_output = 'dataset/record/train.tfrecords'
    create_records(image_dir,train_labels, train_record_output, resize_height, resize_width,shuffle,log)
    train_nums=get_example_nums(train_record_output)
    print("save train example nums={}".format(train_nums))

    # 产生val.record文件
    image_dir='dataset/val'
    val_labels = 'dataset/val.txt'  # 图片路径
    val_record_output = 'dataset/record/val.tfrecords'
    create_records(image_dir,val_labels, val_record_output, resize_height, resize_width,shuffle,log)
    val_nums=get_example_nums(val_record_output)
    print("save val example nums={}".format(val_nums))

    # 测试显示函数
    # disp_records(train_record_output,resize_height, resize_width)
    batch_test(train_record_output,resize_height, resize_width)

(2)生成单个record文件 (多label)

    对于多label的情况,你可以在单label的基础上增加多个“label': tf.FixedLenFeature([], tf.int64)“,但每次label个数不一样时,都需要修改,挺麻烦的。这里提供一个方法:label数据也可以像图像数据那样,转为string类型来保存:labels_raw = np.asanyarray(labels,dtype=np.float32).tostring() ,解析时也跟图像数据一样进行解析:tf_label = tf.decode_raw(features['labels'],tf.float32) ,这样,不管多少个label,我们都可以保存为record文件了:

   多label的TXT文件:

0.jpg 0.33 0.55
1.jpg 0.42 0.73
2.jpg 0.16 0.75
3.jpg 0.78 0.66
4.jpg 0.46 0.59
5.jpg 0.46 0.09
6.jpg 0.89 0.93
7.jpg 0.42 0.82
8.jpg 0.39 0.76
9.jpg 0.46 0.40
# -*-coding: utf-8 -*-
"""
    @Project: create_tfrecord
    @File   : create_tf_record_multi_label.py
    @Author : panjq
    @E-mail : [email protected]
    @Date   : 2018-07-27 17:19:54
    @desc   : 将图片数据,多label,保存为单个tfrecord文件
"""

##########################################################################

import tensorflow as tf
import numpy as np
import os
import cv2
import matplotlib.pyplot as plt
import random
from PIL import Image


##########################################################################
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _float_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

# 生成字符串型的属性
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
# 生成实数型的属性
def float_list_feature(value):
  return tf.train.Feature(float_list=tf.train.FloatList(value=value))

def get_example_nums(tf_records_filenames):
    '''
    统计tf_records图像的个数(example)个数
    :param tf_records_filenames: tf_records文件路径
    :return:
    '''
    nums= 0
    for record in tf.python_io.tf_record_iterator(tf_records_filenames):
        nums += 1
    return nums

def show_image(title,image):
    '''
    显示图片
    :param title: 图像标题
    :param image: 图像的数据
    :return:
    '''
    # plt.figure("show_image")
    # print(image.dtype)
    plt.imshow(image)
    plt.axis('on')    # 关掉坐标轴为 off
    plt.title(title)  # 图像题目
    plt.show()

def load_labels_file(filename,labels_num=1,shuffle=False):
    '''
    载图txt文件,文件中每行为一个图片信息,且以空格隔开:图像路径 标签1 标签2,如:test_image/1.jpg 0 2
    :param filename:
    :param labels_num :labels个数
    :param shuffle :是否打乱顺序
    :return:images type->list
    :return:labels type->list
    '''
    images=[]
    labels=[]
    with open(filename) as f:
        lines_list=f.readlines()
        if shuffle:
            random.shuffle(lines_list)

        for lines in lines_list:
            line=lines.rstrip().split(' ')
            label=[]
            for i in range(labels_num):
                label.append(float(line[i+1]))
            images.append(line[0])
            labels.append(label)
    return images,labels

def read_image(filename, resize_height, resize_width,normalization=False):
    '''
    读取图片数据,默认返回的是uint8,[0,255]
    :param filename:
    :param resize_height:
    :param resize_width:
    :param normalization:是否归一化到[0.,1.0]
    :return: 返回的图片数据
    '''

    bgr_image = cv2.imread(filename)
    if len(bgr_image.shape)==2:#若是灰度图则转为三通道
        print("Warning:gray image",filename)
        bgr_image = cv2.cvtColor(bgr_image, cv2.COLOR_GRAY2BGR)

    rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)#将BGR转为RGB
    # show_image(filename,rgb_image)
    # rgb_image=Image.open(filename)
    if resize_height>0 and resize_width>0:
        rgb_image=cv2.resize(rgb_image,(resize_width,resize_height))
    rgb_image=np.asanyarray(rgb_image)
    if normalization:
        # 不能写成:rgb_image=rgb_image/255
        rgb_image=rgb_image/255.0
    # show_image("src resize image",image)
    return rgb_image


def get_batch_images(images,labels,batch_size,labels_nums,one_hot=False,shuffle=False,num_threads=1):
    '''
    :param images:图像
    :param labels:标签
    :param batch_size:
    :param labels_nums:标签个数
    :param one_hot:是否将labels转为one_hot的形式
    :param shuffle:是否打乱顺序,一般train时shuffle=True,验证时shuffle=False
    :return:返回batch的images和labels
    '''
    min_after_dequeue = 200
    capacity = min_after_dequeue + 3 * batch_size  # 保证capacity必须大于min_after_dequeue参数值
    if shuffle:
        images_batch, labels_batch = tf.train.shuffle_batch([images,labels],
                                                                    batch_size=batch_size,
                                                                    capacity=capacity,
                                                                    min_after_dequeue=min_after_dequeue,
                                                                    num_threads=num_threads)
    else:
        images_batch, labels_batch = tf.train.batch([images,labels],
                                                        batch_size=batch_size,
                                                        capacity=capacity,
                                                        num_threads=num_threads)
    if one_hot:
        labels_batch = tf.one_hot(labels_batch, labels_nums, 1, 0)
    return images_batch,labels_batch

def read_records(filename,resize_height, resize_width,type=None):
    '''
    解析record文件:源文件的图像数据是RGB,uint8,[0,255],一般作为训练数据时,需要归一化到[0,1]
    :param filename:
    :param resize_height:
    :param resize_width:
    :param type:选择图像数据的返回类型
         None:默认将uint8-[0,255]转为float32-[0,255]
         normalization:归一化float32-[0,1]
         standardization:归一化float32-[0,1],再减均值中心化
    :return:
    '''
    # 创建文件队列,不限读取的数量
    filename_queue = tf.train.string_input_producer([filename])
    # create a reader from file queue
    reader = tf.TFRecordReader()
    # reader从文件队列中读入一个序列化的样本
    _, serialized_example = reader.read(filename_queue)
    # get feature from serialized example
    # 解析符号化的样本
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image_raw': tf.FixedLenFeature([], tf.string),
            'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'depth': tf.FixedLenFeature([], tf.int64),
            'labels': tf.FixedLenFeature([], tf.string)
        }
    )
    tf_image = tf.decode_raw(features['image_raw'], tf.uint8)#获得图像原始的数据

    tf_height = features['height']
    tf_width = features['width']
    tf_depth = features['depth']
    # tf_label = tf.cast(features['labels'], tf.float32)
    tf_label = tf.decode_raw(features['labels'],tf.float32)

    # PS:恢复原始图像数据,reshape的大小必须与保存之前的图像shape一致,否则出错
    # tf_image=tf.reshape(tf_image, [-1])    # 转换为行向量
    tf_image=tf.reshape(tf_image, [resize_height, resize_width, 3]) # 设置图像的维度

    tf_label=tf.reshape(tf_label, [2]) # 设置图像的维度


    # 恢复数据后,才可以对图像进行resize_images:输入uint->输出float32
    # tf_image=tf.image.resize_images(tf_image,[224, 224])

    # [3]数据类型处理
    # 存储的图像类型为uint8,tensorflow训练时数据必须是tf.float32
    if type is None:
        tf_image = tf.cast(tf_image, tf.float32)
    elif type == 'normalization':  # [1]若需要归一化请使用:
        # 仅当输入数据是uint8,才会归一化[0,255]
        # tf_image = tf.cast(tf_image, dtype=tf.uint8)
        # tf_image = tf.image.convert_image_dtype(tf_image, tf.float32)
        tf_image = tf.cast(tf_image, tf.float32) * (1. / 255.0)  # 归一化
    elif type == 'standardization':  # 标准化
        # tf_image = tf.cast(tf_image, dtype=tf.uint8)
        # tf_image = tf.image.per_image_standardization(tf_image)  # 标准化(减均值除方差)
        # 若需要归一化,且中心化,假设均值为0.5,请使用:
        tf_image = tf.cast(tf_image, tf.float32) * (1. / 255) - 0.5  # 中心化

    # 这里仅仅返回图像和标签
    # return tf_image, tf_height,tf_width,tf_depth,tf_label
    return tf_image,tf_label


def create_records(image_dir,file, output_record_dir, resize_height, resize_width,shuffle,log=5):
    '''
    实现将图像原始数据,label,长,宽等信息保存为record文件
    注意:读取的图像数据默认是uint8,再转为tf的字符串型BytesList保存,解析请需要根据需要转换类型
    :param image_dir:原始图像的目录
    :param file:输入保存图片信息的txt文件(image_dir+file构成图片的路径)
    :param output_record_dir:保存record文件的路径
    :param resize_height:
    :param resize_width:
    PS:当resize_height或者resize_width=0是,不执行resize
    :param shuffle:是否打乱顺序
    :param log:log信息打印间隔
    '''
    # 加载文件,仅获取一个label
    labels_num=2
    images_list, labels_list=load_labels_file(file,labels_num,shuffle)

    writer = tf.python_io.TFRecordWriter(output_record_dir)
    for i, [image_name, labels] in enumerate(zip(images_list, labels_list)):
        image_path=os.path.join(image_dir,images_list[i])
        if not os.path.exists(image_path):
            print('Err:no image',image_path)
            continue
        image = read_image(image_path, resize_height, resize_width)
        image_raw = image.tostring()
        if i%log==0 or i==len(images_list)-1:
            print('------------processing:%d-th------------' % (i))
            print('current image_path=%s' % (image_path),'shape:{}'.format(image.shape),'labels:{}'.format(labels))
        # 这里仅保存一个label,多label适当增加"'label': _int64_feature(label)"项
        # label=labels[0]
        # labels_raw="0.12,0,15"
        labels_raw = np.asanyarray(labels,dtype=np.float32).tostring()

        example = tf.train.Example(features=tf.train.Features(feature={
            'image_raw': _bytes_feature(image_raw),
            'height': _int64_feature(image.shape[0]),
            'width': _int64_feature(image.shape[1]),
            'depth': _int64_feature(image.shape[2]),
            'labels': _bytes_feature(labels_raw),

        }))
        writer.write(example.SerializeToString())
    writer.close()

def disp_records(record_file,resize_height, resize_width,show_nums=4):
    '''
    解析record文件,并显示show_nums张图片,主要用于验证生成record文件是否成功
    :param tfrecord_file: record文件路径
    :return:
    '''
    # 读取record函数
    tf_image, tf_label = read_records(record_file,resize_height,resize_width,type='normalization')
    # 显示前4个图片
    init_op = tf.initialize_all_variables()
    with tf.Session() as sess:
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        for i in range(show_nums):
            image,label = sess.run([tf_image,tf_label])  # 在会话中取出image和label
            # image = tf_image.eval()
            # 直接从record解析的image是一个向量,需要reshape显示
            # image = image.reshape([height,width,depth])
            print('shape:{},tpye:{},labels:{}'.format(image.shape,image.dtype,label))
            # pilimg = Image.fromarray(np.asarray(image_eval_reshape))
            # pilimg.show()
            show_image("image:{}".format(label),image)
        coord.request_stop()
        coord.join(threads)


def batch_test(record_file,resize_height, resize_width):
    '''
    :param record_file: record文件路径
    :param resize_height:
    :param resize_width:
    :return:
    :PS:image_batch, label_batch一般作为网络的输入
    '''
    # 读取record函数
    tf_image,tf_label = read_records(record_file,resize_height,resize_width,type='normalization')
    image_batch, label_batch= get_batch_images(tf_image,tf_label,batch_size=4,labels_nums=2,one_hot=False,shuffle=True)

    init = tf.global_variables_initializer()
    with tf.Session() as sess:  # 开始一个会话
        sess.run(init)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        for i in range(4):
            # 在会话中取出images和labels
            images, labels = sess.run([image_batch, label_batch])
            # 这里仅显示每个batch里第一张图片
            show_image("image", images[0, :, :, :])
            print('shape:{},tpye:{},labels:{}'.format(images.shape,images.dtype,labels))

        # 停止所有线程
        coord.request_stop()
        coord.join(threads)


if __name__ == '__main__':
    # 参数设置

    resize_height = 224  # 指定存储图片高度
    resize_width = 224  # 指定存储图片宽度
    shuffle=True
    log=1000
    # 产生train.record文件
    image_dir='dataset_regression/images'
    train_labels = 'dataset_regression/train.txt'  # 图片路径
    train_record_output = 'dataset_regression/record/train.tfrecords'
    create_records(image_dir,train_labels, train_record_output, resize_height, resize_width,shuffle,log)
    train_nums=get_example_nums(train_record_output)
    print("save train example nums={}".format(train_nums))
    # 测试显示函数
    # disp_records(train_record_output,resize_height, resize_width)
    # 产生val.record文件
    image_dir='dataset_regression/images'
    val_labels = 'dataset_regression/val.txt'  # 图片路径
    val_record_output = 'dataset_regression/record/val.tfrecords'
    create_records(image_dir,val_labels, val_record_output, resize_height, resize_width,shuffle,log)
    val_nums=get_example_nums(val_record_output)
    print("save val example nums={}".format(val_nums))
    #
    # # 测试显示函数
    # # disp_records(train_record_output,resize_height, resize_width)
    # batch_test(val_record_output,resize_height, resize_width)

(3)生成多个record文件的方法

      上述该代码只能保存为单个record文件,当图片数据很多时候,会导致单个record文件超级巨大的情况,解决方法就是,将数据分成多个record文件保存,读取时,只需要将多个record文件的路径列表交给“tf.train.string_input_producer”。可以设置参数batchSize的大小,比如batchSize=2000,表示每2000张图片保存为一个*.tfrecords,这样可以避免单个record文件过大的情况。

      完整代码如下:

# -*-coding: utf-8 -*-
"""
    @Project: tf_record_demo
    @File   : tf_record_batchSize.py
    @Author : panjq
    @E-mail : [email protected]
    @Date   : 2018-07-27 17:19:54
    @desc   : 将图片数据保存为多个record文件
"""

##########################################################################

import tensorflow as tf
import numpy as np
import os
import cv2
import math
import matplotlib.pyplot as plt
import random
from PIL import Image


##########################################################################
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
# 生成字符串型的属性
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
# 生成实数型的属性
def float_list_feature(value):
  return tf.train.Feature(float_list=tf.train.FloatList(value=value))

def show_image(title,image):
    '''
    显示图片
    :param title: 图像标题
    :param image: 图像的数据
    :return:
    '''
    # plt.figure("show_image")
    # print(image.dtype)
    plt.imshow(image)
    plt.axis('on')    # 关掉坐标轴为 off
    plt.title(title)  # 图像题目
    plt.show()

def load_labels_file(filename,labels_num=1):
    '''
    载图txt文件,文件中每行为一个图片信息,且以空格隔开:图像路径 标签1 标签2,如:test_image/1.jpg 0 2
    :param filename:
    :param labels_num :labels个数
    :return:images type->list
    :return:labels type->list
    '''
    images=[]
    labels=[]
    with open(filename) as f:
        for lines in f.readlines():
            line=lines.rstrip().split(' ')
            label=[]
            for i in range(labels_num):
                label.append(int(line[i+1]))
            images.append(line[0])
            labels.append(label)
    return images,labels

def read_image(filename, resize_height, resize_width):
    '''
    读取图片数据,默认返回的是uint8,[0,255]
    :param filename:
    :param resize_height:
    :param resize_width:
    :return: 返回的图片数据是uint8,[0,255]
    '''

    bgr_image = cv2.imread(filename)
    if len(bgr_image.shape)==2:#若是灰度图则转为三通道
        print("Warning:gray image",filename)
        bgr_image = cv2.cvtColor(bgr_image, cv2.COLOR_GRAY2BGR)

    rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)#将BGR转为RGB
    # show_image(filename,rgb_image)
    # rgb_image=Image.open(filename)
    if resize_height>0 and resize_width>0:
        rgb_image=cv2.resize(rgb_image,(resize_width,resize_height))
    rgb_image=np.asanyarray(rgb_image)
    # show_image("src resize image",image)

    return rgb_image


def create_records(image_dir,file, record_txt_path, batchSize,resize_height, resize_width):
    '''
    实现将图像原始数据,label,长,宽等信息保存为record文件
    注意:读取的图像数据默认是uint8,再转为tf的字符串型BytesList保存,解析请需要根据需要转换类型
    :param image_dir:原始图像的目录
    :param file:输入保存图片信息的txt文件(image_dir+file构成图片的路径)
    :param output_record_txt_dir:保存record文件的路径
    :param batchSize: 每batchSize个图片保存一个*.tfrecords,避免单个文件过大
    :param resize_height:
    :param resize_width:
    PS:当resize_height或者resize_width=0是,不执行resize
    '''
    if os.path.exists(record_txt_path):
        os.remove(record_txt_path)

    setname, ext = record_txt_path.split('.')

    # 加载文件,仅获取一个label
    images_list, labels_list=load_labels_file(file,1)
    sample_num = len(images_list)
    # 打乱样本的数据
    # random.shuffle(labels_list)
    batchNum = int(math.ceil(1.0 * sample_num / batchSize))

    for i in range(batchNum):
        start = i * batchSize
        end = min((i + 1) * batchSize, sample_num)
        batch_images = images_list[start:end]
        batch_labels = labels_list[start:end]
        # 逐个保存*.tfrecords文件
        filename = setname + '{0}.tfrecords'.format(i)
        print('save:%s' % (filename))

        writer = tf.python_io.TFRecordWriter(filename)
        for i, [image_name, labels] in enumerate(zip(batch_images, batch_labels)):
            image_path=os.path.join(image_dir,batch_images[i])
            if not os.path.exists(image_path):
                print('Err:no image',image_path)
                continue
            image = read_image(image_path, resize_height, resize_width)
            image_raw = image.tostring()
            print('image_path=%s,shape:( %d, %d, %d)' % (image_path,image.shape[0], image.shape[1], image.shape[2]),'labels:',labels)
            # 这里仅保存一个label,多label适当增加"'label': _int64_feature(label)"项
            label=labels[0]
            example = tf.train.Example(features=tf.train.Features(feature={
                'image_raw': _bytes_feature(image_raw),
                'height': _int64_feature(image.shape[0]),
                'width': _int64_feature(image.shape[1]),
                'depth': _int64_feature(image.shape[2]),
                'label': _int64_feature(label)
            }))
            writer.write(example.SerializeToString())
        writer.close()

        # 用txt保存*.tfrecords文件列表
        # record_list='{}.txt'.format(setname)
        with open(record_txt_path, 'a') as f:
            f.write(filename + '\n')

def read_records(filename,resize_height, resize_width):
    '''
    解析record文件
    :param filename:保存*.tfrecords文件的txt文件路径
    :return:
    '''
    # 读取txt中所有*.tfrecords文件
    with open(filename, 'r') as f:
        lines = f.readlines()
        files_list=[]
        for line in lines:
            files_list.append(line.rstrip())

    # 创建文件队列,不限读取的数量
    filename_queue = tf.train.string_input_producer(files_list,shuffle=False)
    # create a reader from file queue
    reader = tf.TFRecordReader()
    # reader从文件队列中读入一个序列化的样本
    _, serialized_example = reader.read(filename_queue)
    # get feature from serialized example
    # 解析符号化的样本
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image_raw': tf.FixedLenFeature([], tf.string),
            'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'depth': tf.FixedLenFeature([], tf.int64),
            'label': tf.FixedLenFeature([], tf.int64)
        }
    )
    tf_image = tf.decode_raw(features['image_raw'], tf.uint8)#获得图像原始的数据

    tf_height = features['height']
    tf_width = features['width']
    tf_depth = features['depth']
    tf_label = tf.cast(features['label'], tf.int32)
    # tf_image=tf.reshape(tf_image, [-1])    # 转换为行向量
    tf_image=tf.reshape(tf_image, [resize_height, resize_width, 3]) # 设置图像的维度
    # 存储的图像类型为uint8,这里需要将类型转为tf.float32
    # tf_image = tf.cast(tf_image, tf.float32)
    # [1]若需要归一化请使用:
    tf_image = tf.image.convert_image_dtype(tf_image, tf.float32)# 归一化
    # tf_image = tf.cast(tf_image, tf.float32) * (1. / 255)  # 归一化
    # [2]若需要归一化,且中心化,假设均值为0.5,请使用:
    # tf_image = tf.cast(tf_image, tf.float32) * (1. / 255) - 0.5 #中心化
    return tf_image, tf_height,tf_width,tf_depth,tf_label

def disp_records(record_file,resize_height, resize_width,show_nums=4):
    '''
    解析record文件,并显示show_nums张图片,主要用于验证生成record文件是否成功
    :param tfrecord_file: record文件路径
    :param resize_height:
    :param resize_width:
    :param show_nums: 默认显示前四张照片

    :return:
    '''
    tf_image, tf_height, tf_width, tf_depth, tf_label = read_records(record_file,resize_height, resize_width)  # 读取函数
    # 显示前show_nums个图片
    init_op = tf.initialize_all_variables()
    with tf.Session() as sess:
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        for i in range(show_nums):
            image,height,width,depth,label = sess.run([tf_image,tf_height,tf_width,tf_depth,tf_label])  # 在会话中取出image和label
            # image = tf_image.eval()
            # 直接从record解析的image是一个向量,需要reshape显示
            # image = image.reshape([height,width,depth])
            print('shape:',image.shape,'label:',label)
            # pilimg = Image.fromarray(np.asarray(image_eval_reshape))
            # pilimg.show()
            show_image("image:%d"%(label),image)
        coord.request_stop()
        coord.join(threads)


def batch_test(record_file,resize_height, resize_width):
    '''
    :param record_file: record文件路径
    :param resize_height:
    :param resize_width:
    :return:
    :PS:image_batch, label_batch一般作为网络的输入
    '''

    tf_image,tf_height,tf_width,tf_depth,tf_label = read_records(record_file,resize_height, resize_width) # 读取函数

    # 使用shuffle_batch可以随机打乱输入:
    # shuffle_batch用法:https://blog.csdn.net/u013555719/article/details/77679964
    min_after_dequeue = 100#该值越大,数据越乱,必须小于capacity
    batch_size = 4
    # capacity = (min_after_dequeue + (num_threads + a small safety margin∗batchsize)
    capacity = min_after_dequeue + 3 * batch_size#容量:一个整数,队列中的最大的元素数

    image_batch, label_batch = tf.train.shuffle_batch([tf_image, tf_label],
                                                      batch_size=batch_size,
                                                      capacity=capacity,
                                                      min_after_dequeue=min_after_dequeue)

    init = tf.global_variables_initializer()
    with tf.Session() as sess:  # 开始一个会话
        sess.run(init)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        for i in range(4):
            # 在会话中取出images和labels
            images, labels = sess.run([image_batch, label_batch])
            # 这里仅显示每个batch里第一张图片
            show_image("image", images[0, :, :, :])
            print(images.shape, labels)
        # 停止所有线程
        coord.request_stop()
        coord.join(threads)


if __name__ == '__main__':
    # 参数设置
    image_dir='dataset/train'
    train_file = 'dataset/train.txt'  # 图片路径
    output_record_txt = 'dataset/record/record.txt'#指定保存record的文件列表
    resize_height = 224  # 指定存储图片高度
    resize_width = 224  # 指定存储图片宽度
    batchSize=8000     #batchSize一般设置为8000,即每batchSize张照片保存为一个record文件
    # 产生record文件
    create_records(image_dir=image_dir,
                   file=train_file,
                   record_txt_path=output_record_txt,
                   batchSize=batchSize,
                   resize_height=resize_height,
                   resize_width=resize_width)

    # 测试显示函数
    disp_records(output_record_txt,resize_height, resize_width)

    # batch_test(output_record_txt,resize_height, resize_width)

1.2. 直接文件读取方式 

    上面介绍的是如何将数据转存为TFrecord文件,训练时再解析TFrecord。这种转存为TFrecord数据格式的方法,虽然高效,但也丧失了灵活性,特别是新增数据或者删除相关数据时,这时就不得不重新制作TFrecord数据了。这就挺麻烦啦,如果不想转为TFrecord文件,可以直接读取图像文件进行训练。

    这种方法比较简单,灵活性很强,但效率很低,因为每次迭代训练,GPU/CPU都要等待数据读取I/O操作,图像文件读取以及预处理过程本身就很耗时,甚至比你迭代一次网络还耗时。解决的方法,就是采用tf.data.Dataset数据读取机制。

    直接文件读取方式的完整代码可以参考如下:

    假设我们有train.txt的文件数据如下:

0.jpg 0
1.jpg 0
2.jpg 0
3.jpg 0
4.jpg 0
5.jpg 1
6.jpg 1
7.jpg 1
8.jpg 1
9.jpg 1

    可以使用下面的方法直接读取图像数据,并产生一个batch的训练数据:

# -*-coding: utf-8 -*-
"""
    @Project: tf_record_demo
    @File   : tf_read_files.py
    @Author : panjq
    @E-mail : [email protected]
    @Date   : 2018-10-14 10:44:06
"""
import tensorflow as tf
import glob
import numpy as np
import os
import matplotlib.pyplot as plt

import cv2
def show_image(title, image):
    '''
    显示图片
    :param title: 图像标题
    :param image: 图像的数据
    :return:
    '''
    # plt.imshow(image, cmap='gray')
    plt.imshow(image)
    plt.axis('on')  # 关掉坐标轴为 off
    plt.title(title)  # 图像题目
    plt.show()


def tf_read_image(filename, resize_height, resize_width):
    '''
    读取图片
    :param filename:
    :param resize_height:
    :param resize_width:
    :return:
    '''
    image_string = tf.read_file(filename)
    image_decoded = tf.image.decode_jpeg(image_string, channels=3)
    # tf_image = tf.cast(image_decoded, tf.float32)
    tf_image = tf.cast(image_decoded, tf.float32) * (1. / 255.0)  # 归一化
    if resize_width>0 and resize_height>0:
        tf_image = tf.image.resize_images(tf_image, [resize_height, resize_width])
    # tf_image = tf.image.per_image_standardization(tf_image)  # 标准化[0,1](减均值除方差)
    return tf_image


def get_batch_images(image_list, label_list, batch_size, labels_nums, resize_height, resize_width, one_hot=False, shuffle=False):
    '''
    :param image_list:图像
    :param label_list:标签
    :param batch_size:
    :param labels_nums:标签个数
    :param one_hot:是否将labels转为one_hot的形式
    :param shuffle:是否打乱顺序,一般train时shuffle=True,验证时shuffle=False
    :return:返回batch的images和labels
    '''
    # 生成队列
    image_que, tf_label = tf.train.slice_input_producer([image_list, label_list], shuffle=shuffle)
    tf_image = tf_read_image(image_que, resize_height, resize_width)
    min_after_dequeue = 200
    capacity = min_after_dequeue + 3 * batch_size  # 保证capacity必须大于min_after_dequeue参数值
    if shuffle:
        images_batch, labels_batch = tf.train.shuffle_batch([tf_image, tf_label],
                                                            batch_size=batch_size,
                                                            capacity=capacity,
                                                            min_after_dequeue=min_after_dequeue)
    else:
        images_batch, labels_batch = tf.train.batch([tf_image, tf_label],
                                                    batch_size=batch_size,
                                                    capacity=capacity)
    if one_hot:
        labels_batch = tf.one_hot(labels_batch, labels_nums, 1, 0)
    return images_batch, labels_batch


def load_image_labels(filename):
    '''
    载图txt文件,文件中每行为一个图片信息,且以空格隔开:图像路径 标签1,如:test_image/1.jpg 0
    :param filename:
    :return:
    '''
    images_list = []
    labels_list = []
    with open(filename) as f:
        lines = f.readlines()
        for line in lines:
            # rstrip:用来去除结尾字符、空白符(包括\n、\r、\t、' ',即:换行、回车、制表符、空格)
            content = line.rstrip().split(' ')
            name = content[0]
            labels = []
            for value in content[1:]:
                labels.append(int(value))
            images_list.append(name)
            labels_list.append(labels)
    return images_list, labels_list


def batch_test(filename, image_dir):
    labels_nums = 2
    batch_size = 4
    resize_height = 200
    resize_width = 200
    image_list, label_list = load_image_labels(filename)
    image_list=[os.path.join(image_dir,image_name) for image_name in image_list]

    image_batch, labels_batch = get_batch_images(image_list=image_list,
                                                 label_list=label_list,
                                                 batch_size=batch_size,
                                                 labels_nums=labels_nums,
                                                 resize_height=resize_height, resize_width=resize_width,
                                                 one_hot=False, shuffle=True)
    with tf.Session() as sess:  # 开始一个会话
        sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        for i in range(4):
            # 在会话中取出images和labels
            images, labels = sess.run([image_batch, labels_batch])
            # 这里仅显示每个batch里第一张图片
            show_image("image", images[0, :, :, :])
            print('shape:{},tpye:{},labels:{}'.format(images.shape, images.dtype, labels))

        # 停止所有线程
        coord.request_stop()
        coord.join(threads)


if __name__ == "__main__":
    image_dir = "./dataset/train"
    filename = "./dataset/train.txt"
    batch_test(filename, image_dir)



2.tf.data.Dataset数据读取机制:Pipeline机制

    要执行训练步骤,您必须首先提取并转换训练数据,然后将其提供给在加速器上运行的模型。然而,在一个简单的同步执行中,当 CPU 正在准备数据时,加速器则处于空闲状态。相反,当加速器正在训练模型时,CPU 则处于空闲状态。因此,训练步骤时间是 CPU 预处理时间和加速器训练时间的总和

prefetch(必须放在最后)

     TensorFlow引入了tf.data.Dataset模块,使其数据读入的操作变得更为方便,而支持多线程(进程)的操作,也在效率上获得了一定程度的提高。使用tf.data.Dataset模块的pipline机制,可实现CPU多线程处理输入的数据,如读取图片和图片的一些的预处理,这样GPU可以专注于训练过程,而CPU去准备数据。

    参考资料:

https://blog.csdn.net/u014061630/article/details/80776975

(五星推荐)TensorFlow全新的数据读取方式:Dataset API入门教程:http://baijiahao.baidu.com/s?id=1583657817436843385&wfr=spider&for=pc

    Pipelining 将一个训练步骤的预处理和模型执行重叠。当加速器正在执行训练步骤 N 时,CPU 正在准备步骤 N + 1 的数据。这样做的目的是可以将步骤时间缩短到极致,包含训练以及提取和转换数据所需时间(而不是总和)。

    如果没有使用 pipelining,则 CPU 和 GPU / TPU 在大部分时间处于闲置状态:

    而使用 pipelining 技术后,空闲时间显著减少:

    tf.data API 通过 tf.data.Dataset.prefetch 转换提供了一个软件 pipelining 操作机制,该转换可用于将数据生成的时间与所消耗时间分离。特别是,转换使用后台线程和内部缓冲区,以便在请求输入数据集之前从输入数据集中预提取元素。因此,为了实现上面说明的 pipelining 效果,您可以将 prefetch(1) 添加为数据集管道的最终转换(如果单个训练步骤消耗 n 个元素,则添加 prefetch(n))。

    tf.data.Dataset.prefetch 提供了 software pipelining 机制。该函数解耦了 数据产生的时间 和 数据消耗的时间。具体来说,该函数有一个后台线程和一个内部缓存区,在数据被请求前,就从 dataset 中预加载一些数据(进一步提高性能)。prefech(n) 一般作为最后一个 transformation,其中 n 为 batch_size。 prefetch 的使用方法如下:   

    要将此更改应用于我们的运行示例,请将:

dataset = dataset.batch(batch_size=FLAGS.batch_size)
return dataset

    更改为:

dataset = dataset.batch(batch_size=FLAGS.batch_size)
dataset = dataset.prefetch(buffer_size=FLAGS.prefetch_buffer_size)
return dataset

    请注意,在任何时候只要有机会将 “制造者” 的工作与 “消费者” 的工作重叠,预取转换就会产生效益。前面的建议只是最常见的应用程序。

map

    使用 tf.data.Dataset.map,我们可以很方便地对数据集中的各个元素进行预处理。因为输入元素之间时独立的,所以可以在多个 CPU 核心上并行地进行预处理。map 变换提供了一个 num_parallel_calls参数去指定并行的级别。

    准备批处理时,可能需要预处理输入元素。为此,tf.data API 提供了 tf.data.Dataset.map 转换,它将用户定义的函数(例如,运行示例中的 parse_fn)应用于输入数据集的每个元素。由于输入元素彼此独立,因此可以跨多个 CPU 内核并行化预处理。为了实现这一点,map 转换提供了 thenum_parallel_calls 参数来指定并行度。例如,下图说明了将 num_parallel_calls = 2 设置为 map 转换的效果:

dataset = dataset.map(map_func=parse_fn, num_parallel_calls=FLAGS.num_parallel_calls)

repeat

    repeat的功能就是将整个序列重复多次,主要用来处理机器学习中的epoch,假设原先的数据是一个epoch,使用repeat(5)就可以将之变成5个epoch:

    如果直接调用repeat()的话,生成的序列就会无限重复下去,没有结束,因此也不会抛出tf.errors.OutOfRangeError异常

实例代码1:dataset.make_initializable_iterator()

# -*-coding: utf-8 -*-
"""
    @Project: fine tuning
    @File   : pipeline.py
    @Author : panjq
    @E-mail : [email protected]
    @Date   : 2018-11-17 20:18:54
"""
import tensorflow as tf
import numpy as np
import glob
import matplotlib.pyplot as plt

width=0
height=0
def show_image(title, image):
    '''
    显示图片
    :param title: 图像标题
    :param image: 图像的数据
    :return:
    '''
    # plt.figure("show_image")
    # print(image.dtype)
    plt.imshow(image)
    plt.axis('on')  # 关掉坐标轴为 off
    plt.title(title)  # 图像题目
    plt.show()


def tf_read_image(filename, label):
    image_string = tf.read_file(filename)
    image_decoded = tf.image.decode_jpeg(image_string, channels=3)
    image = tf.cast(image_decoded, tf.float32)
    if width>0 and height>0:
        image = tf.image.resize_images(image, [height, width])
    image = tf.cast(image, tf.float32) * (1. / 255.0)  # 归一化
    return image, label


def input_fun(files_list, labels_list, batch_size, shuffle=True):
    '''
    :param files_list:
    :param labels_list:
    :param batch_size:
    :param shuffle:
    :return:
    '''
    # 构建数据集
    dataset = tf.data.Dataset.from_tensor_slices((files_list, labels_list))
    if shuffle:
        dataset = dataset.shuffle(100)
    dataset = dataset.repeat()  # 空为无限循环
    dataset = dataset.map(tf_read_image, num_parallel_calls=4)  # num_parallel_calls一般设置为cpu内核数量
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(2)  # software pipelining 机制
    return dataset


if __name__ == '__main__':
    data_dir = 'dataset/image/*.jpg'
    # labels_list = tf.constant([0,1,2,3,4])
    # labels_list = [1, 2, 3, 4, 5]
    files_list = glob.glob(data_dir)
    labels_list = np.arange(len(files_list))
    num_sample = len(files_list)
    batch_size = 1
    dataset = input_fun(files_list, labels_list, batch_size=batch_size, shuffle=False)

    # 需满足:max_iterate*batch_size <=num_sample*num_epoch,否则越界
    max_iterate = 3
    with tf.Session() as sess:
        iterator = dataset.make_initializable_iterator()
        init_op = iterator.make_initializer(dataset)
        sess.run(init_op)
        iterator = iterator.get_next()
        for i in range(max_iterate):
            images, labels = sess.run(iterator)
            show_image("image", images[0, :, :, :])
            print('shape:{},tpye:{},labels:{}'.format(images.shape, images.dtype, labels))

实例代码2:dataset.make_one_shot_iterator()

     上面的迭代器是使用dataset.make_initializable_iterator(),当然一个更简单的方法是使用dataset.make_one_shot_iterator(),下面的代码,可把dataset.make_one_shot_iterator()放在input_fun函数中,直接返回一个迭代器iterator

# -*-coding: utf-8 -*-
"""
    @Project: fine tuning
    @File   : pipeline.py
    @Author : panjq
    @E-mail : [email protected]
    @Date   : 2018-11-17 20:18:54
"""
import tensorflow as tf
import numpy as np
import glob
import matplotlib.pyplot as plt

width = 224
height = 224


def show_image(title, image):
    '''
    显示图片
    :param title: 图像标题
    :param image: 图像的数据
    :return:
    '''
    # plt.figure("show_image")
    # print(image.dtype)
    plt.imshow(image)
    plt.axis('on')  # 关掉坐标轴为 off
    plt.title(title)  # 图像题目
    plt.show()


def tf_read_image(filename, label):
    image_string = tf.read_file(filename)
    image_decoded = tf.image.decode_jpeg(image_string, channels=3)
    image = tf.cast(image_decoded, tf.float32)
    if width > 0 and height > 0:
        image = tf.image.resize_images(image, [height, width])
    image = tf.cast(image, tf.float32) * (1. / 255.0)  # 归一化
    return image, label


def input_fun(files_list, labels_list, batch_size, shuffle=True):
    '''
    :param files_list:
    :param labels_list:
    :param batch_size:
    :param shuffle:
    :return:
    '''
    # 构建数据集
    dataset = tf.data.Dataset.from_tensor_slices((files_list, labels_list))
    if shuffle:
        dataset = dataset.shuffle(100)
    dataset = dataset.repeat()  # 空为无限循环
    dataset = dataset.map(tf_read_image, num_parallel_calls=4)  # num_parallel_calls一般设置为cpu内核数量
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(2)  # software pipelining 机制

    iterator = dataset.make_one_shot_iterator()
    return iterator


if __name__ == '__main__':
    data_dir = './data/demo_data/*.jpg'
    # labels_list = tf.constant([0,1,2,3,4])
    # labels_list = [1, 2, 3, 4, 5]
    files_list = glob.glob(data_dir)
    labels_list = np.arange(len(files_list))
    num_sample = len(files_list)
    batch_size = 4
    iterator = input_fun(files_list, labels_list, batch_size=batch_size, shuffle=False)

    # 需满足:max_iterate*batch_size <=num_sample*num_epoch,否则越界
    max_iterate = 3
    with tf.Session() as sess:
        # iterator = dataset.make_initializable_iterator()
        # init_op = iterator.make_initializer(dataset)
        # sess.run(init_op)
        iterator = iterator.get_next()
        for i in range(max_iterate):
            images, labels = sess.run(iterator)
            show_image("image", images[0, :, :, :])
            print('shape:{},tpye:{},labels:{}'.format(images.shape, images.dtype, labels))

实例代码3:产生用于训练的图和label

假设train.txt的数据如下:

0_8354.jpg 8 3 5 4
1_3621.jpg 3 6 2 1
2_4326.jpg 4 3 2 6
3_7711.jpg 7 7 1 1
# -*-coding: utf-8 -*-
"""
    @Project: verification_code
    @File   : dataset.py
    @Author : panjq
    @E-mail : [email protected]
    @Date   : 2019-03-03 18:45:13
"""
import tensorflow as tf
import numpy as np
import glob
import os
import matplotlib.pyplot as plt
from utils import file_processing,image_processing

print("TF Version:{}".format(tf.__version__))

resize_height = 0  # 指定存储图片高度
resize_width = 0  # 指定存储图片宽度

def load_image_labels(filename):
    '''
    载图txt文件,文件中每行为一个图片信息,且以空格隔开:图像路径 标签1 标签1,如:test_image/1.jpg 0 2
    :param filename:
    :return:
    '''
    images_list=[]
    labels_list=[]
    with open(filename) as f:
        lines = f.readlines()
        for line in lines:
            #rstrip:用来去除结尾字符、空白符(包括\n、\r、\t、' ',即:换行、回车、制表符、空格)
            content=line.rstrip().split(' ')
            name=content[0]
            labels=[]
            for value in content[1:]:
                labels.append(int(value))
            images_list.append(name)
            labels_list.append(labels)
    return images_list,labels_list

def show_image(title, image):
    '''
    显示图片
    :param title: 图像标题
    :param image: 图像的数据
    :return:
    '''
    # plt.figure("show_image")
    # print(image.dtype)
    plt.imshow(image)
    plt.axis('on')  # 关掉坐标轴为 off
    plt.title(title)  # 图像题目
    plt.show()

def tf_resize_image(image, width=0, height=0):
    if (width is None) or (height is None):  # 错误写法:resize_height and resize_width is None
        return image
    image = tf.image.resize_images(image, [height, width])
    return image


def tf_read_image(file, width, height):
    image_string = tf.read_file(file)
    image_decoded = tf.image.decode_jpeg(image_string, channels=3)
    image = tf.cast(image_decoded, tf.float32)
    image=tf_resize_image(image, width, height)
    image = tf.cast(image, tf.float32) * (1. / 255.0)  # 归一化
    return image

def map_read_image(files_list, labels_list):
    tf_image=tf_read_image(files_list,resize_width,resize_height)
    return tf_image,labels_list

def input_fun(files_list, labels_list, batch_size, shuffle=True):
    '''
    :param orig_image:
    :param dest_image:
    :param batch_size:
    :param num_epoch:
    :param shuffle:
    :return:
    '''
    # 构建数据集
    dataset = tf.data.Dataset.from_tensor_slices((files_list, labels_list))#TF version>=1.4
    # dataset = tf.contrib.data.Dataset.from_tensor_slices((files_list, labels_list))#TF version<1.4

    if shuffle:
        dataset = dataset.shuffle(100)
    dataset = dataset.repeat()  # 空为无限循环
    # dataset = dataset.map(map_read_image, num_parallel_calls=4)  # num_parallel_calls一般设置为cpu内核数量
    dataset = dataset.map(map_read_image)  # num_parallel_calls一般设置为cpu内核数量

    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(2)  # software pipelining 机制
    dataset = dataset.make_one_shot_iterator()
    return dataset


def get_image_data(images_list, image_dir,labels_list, batch_size, re_height, re_width, shuffle=False):
    global resize_height
    global resize_width
    resize_height = re_height  # 指定存储图片高度
    resize_width = re_width    # 指定存储图片宽度
    image_list = [os.path.join(image_dir, name) for name in images_list]
    dataset = input_fun(image_list, labels_list, batch_size, shuffle)
    return dataset

if __name__ == '__main__':
    filename='../dataset/train.txt'
    image_dir="E:/TensoFlow/verification_code/dataset/train"
    images_list, labels_list=load_image_labels(filename)
    batch_size = 4
    dataset=get_image_data(images_list, image_dir,labels_list, batch_size, re_height=None, re_width=None, shuffle=False)
    # 需满足:max_iterate*batch_size <=num_sample*num_epoch,否则越界
    max_iterate = 3
    with tf.Session() as sess:
        # dataset = dataset.make_initializable_iterator()
        # init_op = dataset.make_initializer(dataset)
        # sess.run(init_op)
        dataset = dataset.get_next()
        for i in range(max_iterate):
            images, labels = sess.run(dataset)
            print('shape:{},tpye:{},labels:{}'.format(images.shape, images.dtype, labels))
            show_image("image", images[0, :, :, :])

实例代码4:产生用于训练的原始图和target目标图

# -*-coding: utf-8 -*-
"""
    @Project: triple_path_networks
    @File   : load_data.py
    @Author : panjq
    @E-mail : [email protected]
    @Date   : 2018-11-29 11:40:37
"""

import tensorflow as tf

import glob
import numpy as np
import utils.image_processing as image_processing
import os
print("TF Version:{}".format(tf.__version__))

resize_height = 0  # 指定存储图片高度
resize_width = 0   # 指定存储图片宽度

def write_data(file, content_list, model):
    with open(file, mode=model) as f:
        for line in content_list:
            f.write(line + "\n")

def read_data(file):
    with open(file, mode="r") as f:
        content_list = f.readlines()
        content_list = [content.rstrip() for content in content_list]
    return content_list

def read_train_val_data(filename,factor=0.8):
    image_list = read_data(filename)
    trian_num=int(len(image_list)*factor)
    train_list = image_list[:trian_num]
    val_list = image_list[trian_num:]
    print("data info***************************")
    print("--train nums:{}".format(len(train_list)))
    print("--val   nums:{}".format(len(val_list)))
    print("************************************")
    return train_list,val_list

def tf_resize_image(image,width=0,height=0):
    if height>0 and width>0:
        image = tf.image.resize_images(image, [height, width])
    return image

def tf_read_image(file,width=224,height=224):
    image_string = tf.read_file(file)
    image_decoded = tf.image.decode_jpeg(image_string, channels=3)
    image = tf.cast(image_decoded, tf.float32)
    if height>0 and width>0:
        image = tf.image.resize_images(image, [height, width])
    image = tf.cast(image, tf.float32) * (1. / 255.0)  # 归一化
    return image

def map_read_image(orig_file, dest_file):
    orig_image=tf_read_image(orig_file,resize_width,resize_height)
    dest_image=tf_read_image(dest_file,resize_width,resize_height)
    return orig_image,dest_image

def input_fun(orig_image, dest_image, batch_size, shuffle=True):
    '''
    :param orig_image:
    :param dest_image:
    :param batch_size:
    :param num_epoch:
    :param shuffle:
    :return:
    '''
    # 构建数据集
    # dataset = tf.data.Dataset.from_tensor_slices((orig_image, dest_image))
    dataset = tf.contrib.data.Dataset.from_tensor_slices((orig_image, dest_image))

    if shuffle:
        dataset = dataset.shuffle(100)
    dataset = dataset.repeat()  #空为无限循环
    # dataset = dataset.map(map_read_image, num_parallel_calls=4)  # num_parallel_calls一般设置为cpu内核数量
    dataset = dataset.map(map_read_image)  # num_parallel_calls一般设置为cpu内核数量

    dataset = dataset.batch(batch_size)
    # dataset = dataset.prefetch(2)  # software pipelining 机制
    return dataset

def get_image_data(file_list,orig_dir,dest_dir,batch_size,re_height,re_width,shuffle=False):
    global resize_height
    global resize_width
    resize_height = re_height  # 指定存储图片高度
    resize_width = re_width  # 指定存储图片宽度

    orig_image_list=[os.path.join(orig_dir,name) for name in file_list]
    dest_image_list=[os.path.join(dest_dir,name) for name in file_list]
    dataset = input_fun(orig_image_list, dest_image_list, batch_size=batch_size,shuffle=shuffle)
    return dataset

if __name__ == '__main__':
    orig_dir="../dataset/blackberry/blackberry"
    dest_dir="../dataset/blackberry/canon"
    filename="../dataset/blackberry/filelist.txt"
    batch_size = 1
    file_list=read_data(filename)
    dataset = get_image_data(file_list,orig_dir, dest_dir, batch_size=batch_size,shuffle=False)
    # 迭代次数:max_iterate=10
    # 需满足:max_iterate*batch_size <=num_sample*num_epoch,否则越界
    max_iterate = 5
    with tf.Session() as sess:
        iterator = dataset.make_initializable_iterator()
        init_op = iterator.make_initializer(dataset)
        sess.run(init_op)
        iterator = iterator.get_next()
        for i in range(max_iterate):
            orig_image, dest_image = sess.run(iterator)
            image_processing.show_image("orig_image", orig_image[0, :, :, :])
            image_processing.show_image("dest_image", dest_image[0, :, :, :])
            print('orig_image:{},dest_image:{}'.format(orig_image.shape, dest_image.shape))

实例代码5: tf.data.Dataset.from_generator

    tf.data.Dataset.from_tensor_slices并不支持输入长度不同list,比如以下代码

t = [[4,2], [3,4,5]]
dataset = tf.data.Dataset.from_tensor_slices(t)

    将会报错:

ValueError: Argument must be a dense tensor: [[4, 2], [3, 4, 5]] - got shape [2], but wanted [2, 2].

    一种决解的方法,采用  tf.data.Dataset.from_generator生成器:

import tensorflow as tf
import numpy as np
data1 = np.array([[1], [2, 3], [3, 4]])
data2 = np.array([[10], [20, 30], [30, 40]])

def data_generator():
    for el, e2 in zip(data1, data2):
        yield el, e2

dataset = tf.data.Dataset.from_generator(data_generator,
                                         output_types=(tf.int32, tf.int32),
                                         output_shapes=(None, None)) #或者output_shapes=(tf.TensorShape([None]), tf.TensorShape([None]))

iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
max_iter = 3
with tf.Session() as sess:
    for i in range(max_iter):
        d1, d2 = sess.run(next_element)
        print("d1:{}".format(d1))
        print("d2:{}".format(d2))
        print("******************************")

  输出:

d1:[1]
d2:[10]
******************************
d1:[2 3]
d2:[20 30]
******************************
d1:[3 4]
d2:[30 40]
******************************

参考资料:

https://stackoverflow.com/questions/47580716/how-to-input-a-list-of-lists-with-different-sizes-in-tf-data-dataset

https://blog.csdn.net/foreseerwang/article/details/80572182

https://blog.csdn.net/dqcfkyqdxym3f8rb0/article/details/79342369(五星推荐)


3. 用Python循环产生批量数据batch

     这部分请参考本人的博客《Python循环产生批量数据batch》  https://blog.csdn.net/guyuealian/article/details/83473298

    上面提到的方法都是在TensorFlow提高API接口完成的,数据预处理也必须依赖TensorFlow的API接口。当遇到一些特殊处理,而TensorFlow没有相应的接口时,就比较尴尬。比如要对输入的图像进行边缘检测处理时,这时能想到就是用OpenCV的Canny算法,一种简单的方法就是,每次sess.run()获得图像数据后,再调用OpenCV的Canny算法……是的,有的麻烦!

     这里提供一个我自己设计方法,不依赖TensorFlow,灵活性很强,你可以对数据进行任意的操作,可以使用OpenCV,numpy等任意的库函数。

   TXT文本如下,格式:图片名 label1 label2 ,注意label可以多个

1.jpg 1 11
2.jpg 2 12
3.jpg 3 13
4.jpg 4 14
5.jpg 5 15
6.jpg 6 16
7.jpg 7 17
8.jpg 8 18

    要想产生batch数据,关键是要用到Python的关键字yield,实现一个batch一个batch的返回数据,代码实现主要有两个方法:

def get_data_batch(inputs, batch_size=None, shuffle=False):
    '''
    循环产生批量数据batch
    :param inputs: list数据
    :param batch_size: batch大小
    :param shuffle: 是否打乱inputs数据
    :return: 返回一个batch数据
    '''
def get_next_batch(batch):
    return batch.__next__()

    使用时,将数据传到 get_data_batch( )方法,然后使用get_next_batch( )获得一个batch数据,完整的Python代码如下:

# -*-coding: utf-8 -*-
"""
    @Project: create_batch_data
    @File   : create_batch_data.py
    @Author : panjq
    @E-mail : [email protected]
    @Date   : 2017-10-27 18:20:15
"""
import math
import random
import os
import glob
import numpy as np


def get_data_batch(inputs, batch_size=None, shuffle=False):
    '''
    循环产生批量数据batch
    :param inputs: list类型数据,多个list,请[list0,list1,...]
    :param batch_size: batch大小
    :param shuffle: 是否打乱inputs数据
    :return: 返回一个batch数据
    '''
    rows = len(inputs[0])
    indices = list(range(rows))
    # 如果输入是list,则需要转为list
    if shuffle:
        random.seed(100)
        random.shuffle(indices)
    while True:
        batch_indices = np.asarray(indices[0:batch_size])  # 产生一个batch的index
        indices = indices[batch_size:] + indices[:batch_size]  # 循环移位,以便产生下一个batch
        batch_data = []
        for data in inputs:
            data = np.asarray(data)
            temp_data=data[batch_indices] #使用下标查找,必须是ndarray类型类型
            batch_data.append(temp_data.tolist())
        yield batch_data

def get_data_batch2(inputs, batch_size=None, shuffle=False):
    '''
    循环产生批量数据batch
    :param inputs: list类型数据,多个list,请[list0,list1,...]
    :param batch_size: batch大小
    :param shuffle: 是否打乱inputs数据
    :return: 返回一个batch数据
    '''
    # rows,cols=inputs.shape
    rows = len(inputs[0])
    indices = list(range(rows))
    if shuffle:
        random.seed(100)
        random.shuffle(indices)
    while True:
        batch_indices = indices[0:batch_size]  # 产生一个batch的index
        indices = indices[batch_size:] + indices[:batch_size]  # 循环移位,以便产生下一个batch
        batch_data = []
        for data in inputs:
            temp_data = find_list(batch_indices, data)
            batch_data.append(temp_data)
        yield batch_data


def find_list(indices, data):
    out = []
    for i in indices:
        out = out + [data[i]]
    return out


def get_list_batch(inputs, batch_size=None, shuffle=False):
    '''
    循环产生batch数据
    :param inputs: list数据
    :param batch_size: batch大小
    :param shuffle: 是否打乱inputs数据
    :return: 返回一个batch数据
    '''
    if shuffle:
        random.shuffle(inputs)
    while True:
        batch_inouts = inputs[0:batch_size]
        inputs = inputs[batch_size:] + inputs[:batch_size]  # 循环移位,以便产生下一个batch
        yield batch_inouts


def load_file_list(text_dir):
    text_dir = os.path.join(text_dir, '*.txt')
    text_list = glob.glob(text_dir)
    return text_list


def get_next_batch(batch):
    return batch.__next__()


def load_image_labels(finename):
    '''
    载图txt文件,文件中每行为一个图片信息,且以空格隔开:图像路径 标签1 标签1,如:test_image/1.jpg 0 2
    :param test_files:
    :return:
    '''
    images_list = []
    labels_list = []
    with open(finename) as f:
        lines = f.readlines()
        for line in lines:
            # rstrip:用来去除结尾字符、空白符(包括\n、\r、\t、' ',即:换行、回车、制表符、空格)
            content = line.rstrip().split(' ')
            name = content[0]
            labels = []
            for value in content[1:]:
                labels.append(float(value))
            images_list.append(name)
            labels_list.append(labels)
    return images_list, labels_list


if __name__ == '__main__':
    filename = './training_data/test.txt'
    images_list, labels_list = load_image_labels(filename)

    # 若输入为np.arange数组,则需要tolist()为list类型,如:
    # images_list = np.reshape(np.arange(8*3), (8,3))
    # labels_list = np.reshape(np.arange(8*3), (8,3))
    # images_list=images_list.tolist()
    # labels_list=labels_list.tolist()

    iter = 5  # 迭代3次,每次输出一个batch个
    # batch = get_data_batch([images_list, labels_list], batch_size=3, shuffle=False)
    batch = get_data_batch2(inputs=[images_list,labels_list], batch_size=5, shuffle=True)

    for i in range(iter):
        print('**************************')
        batch_images, batch_labels = get_next_batch(batch)
        print('batch_images:{}'.format(batch_images))
        print('batch_labels:{}'.format(batch_labels))


   运行输出结果为:

**************************
batch_images:['1.jpg', '2.jpg', '3.jpg']
batch_labels:[[1.0, 11.0], [2.0, 12.0], [3.0, 13.0]]
**************************
batch_images:['4.jpg', '5.jpg', '6.jpg']
batch_labels:[[4.0, 14.0], [5.0, 15.0], [6.0, 16.0]]
**************************
batch_images:['7.jpg', '8.jpg', '1.jpg']
batch_labels:[[7.0, 17.0], [8.0, 18.0], [1.0, 11.0]]
**************************
batch_images:['2.jpg', '3.jpg', '4.jpg']
batch_labels:[[2.0, 12.0], [3.0, 13.0], [4.0, 14.0]]
**************************
batch_images:['5.jpg', '6.jpg', '7.jpg']
batch_labels:[[5.0, 15.0], [6.0, 16.0], [7.0, 17.0]]

Process finished with exit code 0


4.参考资料:

[1]https://blog.csdn.net/happyhorizion/article/details/77894055  (五星推荐)

[2]https://blog.csdn.net/ywx1832990/article/details/78462582

[3]https://blog.csdn.net/csuzhaoqinghui/article/details/51377941

[4]《tf.data API,构建高性能 TensorFlow 输入管道

猜你喜欢

转载自blog.csdn.net/guyuealian/article/details/85106012