tensorflow使用——(六)tfrecord数据操作

前言:

为了更加展示tfrecord数据的相关操作,笔者后续又写了一个实践的简单例子进一步解释,具体可以看:

https://blog.csdn.net/weixin_42001089/article/details/90236241

正文:

tfrecord数据文件是一种将图像数据和标签统一存储的二进制文件,能更好的利用内存,在tensorflow中快速的复制,移动,读取,存储等.

一首先是转化为tfrecord文件格式:


  
  
  1. with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
  2. #生成protocol数据类型
  3. example = image_to_tfexample(image, labels)
  4. tfrecord_writer.write(example.SerializeToString())

其中output_filename就是定义输出的文件位置如./datasets/train.tfrecords

然后通过tf.python_io.TFRecordWriter class中的write方法将tfrecord文件写入到output_filename

一般的话这里会将数据集分成测试集和训练集,所以可以这样定义一个生成tfrecord的函数:


  
  
  1. def gen_tfrecord(split_name, filenames, dataset_dir):
  2. assert split_name in [ 'train', 'test']
  3. with tf.Session() as sess:
  4. #定义tfrecord文件的路径+名字
  5. output_filename = os.path.join(dataset_dir,split_name + '.tfrecords')
  6. with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
  7. for i,filename in enumerate(filenames):
  8. try:
  9. sys.stdout.write( '\r>> Converting image %d/%d' % (i+ 1, len(filenames)))
  10. sys.stdout.flush()
  11. #生成protocol数据类型
  12. example = image_to_tfexample(image, labels)
  13. tfrecord_writer.write(example.SerializeToString())
  14. except IOError as e:
  15. print( 'Could not read:',filename)
  16. print( 'Error:',e)
  17. sys.stdout.write( '\n')
  18. sys.stdout.flush()

split_name就是指定是训练集还是测试集

filenames是每一张图片的路径

然后外部通过下面切分为训练集还是测试集


  
  
  1. training_filenames = photo_filenames[FLAGS. test_num:]
  2. testing_filenames = photo_filenames[ :FLAGS.test_num]

然后调用即可


  
  
  1. gen_tfrecord( 'train', training_filenames,FLAGS.dataset_dir)
  2. gen_tfrecord( 'test', testing_filenames, FLAGS.dataset_dir)

其中dataset_dir就是图片的目录./datasets/images/

关于image和labels,对源数据集做预处理得到的

例如image的获得:


  
  
  1. #读取图片
  2. image = Image.open(filename)
  3. #根据模型的结构resize
  4. image = image_data.resize(( 224, 224))
  5. #灰度化
  6. image = np. array(image_data.convert( 'L'))
  7. #将图片转化为bytes
  8. image= image_data.tobytes()

labels的获得:


  
  
  1. #获取label
  2. labels = filename. split( '/')[- 1][ 0: 4]

关于二者的获取可以在外部处理也可以集成到gen_tfrecord函数中,生成的tfrecord也同时放到了图片目录下。

故:


  
  
  1. def gen_tfrecord(split_name, filenames, dataset_dir):
  2. assert split_name in [ 'train', 'test']
  3. with tf.Session() as sess:
  4. #定义tfrecord文件的路径+名字
  5. output_filename = os.path.join(dataset_dir,split_name + '.tfrecords')
  6. with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
  7. for i,filename in enumerate(filenames):
  8. try:
  9. sys.stdout.write( '\r>> Converting image %d/%d' % (i+ 1, len(filenames)))
  10. sys.stdout.flush()
  11. #读取图片
  12. image_data = Image.open(filename)
  13. #根据模型的结构resize
  14. image_data = image_data.resize(( 224, 224))
  15. #灰度化
  16. image_data = np.array(image_data.convert( 'L'))
  17. #将图片转化为bytes
  18. image_data = image_data.tobytes()
  19. #获取label
  20. labels = filename.split( '/')[ -1][ 0: 4]
  21. #生成protocol数据类型
  22. example = image_to_tfexample(image_data, labels)
  23. tfrecord_writer.write(example.SerializeToString())
  24. except IOError as e:
  25. print( 'Could not read:',filename)
  26. print( 'Error:',e)
  27. sys.stdout.write( '\n')
  28. sys.stdout.flush()

上面的数据集中图片的名字正好是其label,当然二者获取的方式不尽相同,也有可能是images目录下对应三个子目录例如:

cat,dog,fish然后每个目录下面是对应的图片,这时候预处理无非就是要变,定义的gen_tfrecord函数也可能稍微变一下

好了接下来说一下当中的 image_to_tfexample函数,这个也需要自己定义:


  
  
  1. def image_to_tfexample(image, label):
  2. return tf.train.Example(features=tf.train.Features(feature={
  3. 'image': tf.train.Feature(bytes_list = tf.train.BytesList(value=[img]))
  4. 'label': tf.train.Feature(int64_list = tf.train.Int64List(value=[label])),
  5. }))

tf.train.Example 是一块 buffer即协议缓冲区,其中包含了各种feature并以字典的形式赋值。

关于这里说明三点:

1 正如value=[XXX]其中XXX必须是列表形式,也就是说如果传进来的label是列表,那么就可以这样写:

'label': tf.train.Feature(int64_list = tf.train.Int64List(value=label)
  
  

2 传进来的不必要非得是一个label,什么意思呢?比如多任务的时候,可以将label(比如验证码:256)拆分为多个label,每个数字代表一个label


  
  
  1. #获取label
  2. labels = filename.split( '/')[ -1][ 0: 3]
  3. num_labels = []
  4. for j in range( 3):
  5. num_labels. append( int(labels[j]))

那么就是这样:

example = image_to_tfexample(image_data, num_labels[0], num_labels[1], num_labels[2])
  
  

  
  
  1. def image_to_tfexample(image_data, label0, label1, label2):
  2. return tf.train.Example(features=tf.train.Features(feature={
  3. 'image': tf.train.Feature(bytes_list = tf.train.BytesList(value=[img]))
  4. 'label0': tf.train.Feature(int64_list = tf.train.Int64List(value=label0)),
  5. 'label1': tf.train.Feature(int64_list = tf.train.Int64List(value=label1)),
  6. 'label2': tf.train.Feature(int64_list = tf.train.Int64List(value=label2)),
  7. }))

3 tfrecord支持的格式除了上面的整型和二进制二种格式,还支持浮点数即

tf.train.Feature(bytes_list = tf.train.FloatList(value=[float_scalar]))
  
  

二 读取tfrecord文件:

首先介绍下epoch,比如我们要训练的数据集是10张图片即1.jpg、2.jpg、3.jpg、....10.jpg,我们先全部训练这10张图片,然后还可以再来一轮,就是再用这十张图片训练一次,这里的epoch就是轮数当epoch=20时,就是用20轮数据集。

所以过程是这样的:

当epoch=10时,

现将10张图片全部都装载到文件名队列(q1)(装载10次)

q1
1.jpg
2.jpg
.......
10.jpg
1.jpg
2.jpg
.......
10.jpg
................
................
1.jpg
2.jpg
.......
10.jpg

当我们的程序sess.run运行时,内存队列(q2)会从q1依次读取10图片到q2

tf中提供了相关API

使用tf.train.string_input_producer函数,系统会自动将它转为一个文件名队列,其有两个参数num_epochs和shuffle

num_epochs就是epoch数,shuffle就是说是否将一个epoch内文件的顺序是打乱,即不按照1.jpg、2.jpg、3.jpg、....10.jpg,而是9.jpg、5.jpg、3.jpg、....10.jpg等等,shuffle=Ture时,就是打乱
 

关于内存队列不用我们自己建立,使用各种reader对象从文件名队列中读取数据就可以了,具体到tfrecord的reader即为

  reader = tf.TFRecordReader()

之后调用tf.TFRecordReader的tf.parse_single_example解析器,将Example协议缓冲区(protocol buffer)解析为张量

解析了image后还进行了一些预处理,整体如下:


  
  
  1. def read_and_decode(filename):
  2. # 根据文件名生成一个队列
  3. filename_queue = tf.train.string_input_producer([filename])
  4. reader = tf.TFRecordReader()
  5. # 返回文件名和文件
  6. _, serialized_example = reader.read(filename_queue)
  7. features = tf.parse_single_example(serialized_example,
  8. features={
  9. 'image' : tf.FixedLenFeature([], tf.string),
  10. 'label': tf.FixedLenFeature([], tf.int64),
  11. })
  12. # 获取图片数据
  13. image = tf.decode_raw(features[ 'image'], tf.uint8)
  14. # tf.train.shuffle_batch必须确定shape
  15. image = tf.reshape(image, [ 224, 224])
  16. # 图片预处理
  17. image = tf.cast(image, tf.float32) / 255.0
  18. image = tf.subtract(image, 0.5)
  19. image = tf.multiply(image, 2.0)
  20. # 获取label
  21. label = tf.cast(features[ 'label'], tf.int32)
  22. return image, label

如果是上面所说的多任务有多个标签的相应读取tfrecord函数可以为:


  
  
  1. def read_and_decode(filename):
  2. # 根据文件名生成一个队列
  3. filename_queue = tf.train.string_input_producer([filename])
  4. reader = tf.TFRecordReader()
  5. # 返回文件名和文件
  6. _, serialized_example = reader.read(filename_queue)
  7. features = tf.parse_single_example(serialized_example,
  8. features={
  9. 'image' : tf.FixedLenFeature([], tf.string),
  10. 'label0': tf.FixedLenFeature([], tf.int64),
  11. 'label1': tf.FixedLenFeature([], tf.int64),
  12. 'label2': tf.FixedLenFeature([], tf.int64),
  13. })
  14. # 获取图片数据
  15. image = tf.decode_raw(features[ 'image'], tf.uint8)
  16. # tf.train.shuffle_batch必须确定shape
  17. image = tf.reshape(image, [ 224, 224])
  18. # 图片预处理
  19. image = tf.cast(image, tf.float32) / 255.0
  20. image = tf.subtract(image, 0.5)
  21. image = tf.multiply(image, 2.0)
  22. # 获取label
  23. label0 = tf.cast(features[ 'label0'], tf.int32)
  24. label1 = tf.cast(features[ 'label1'], tf.int32)
  25. label2 = tf.cast(features[ 'label2'], tf.int32)
  26. return image, label0, label1, label2

在实际中要分batch进行读取数据(就是说比如数据集有1000M,每一个batch=10M):

那么在一个epoch内分100次读取,每次转载10M的数据,关于batch,tf一般有两个相关的API即tf.train.batch和tf.train.shuffle_batch

tf.train.batch:


  
  
  1. batch(tensors, batch_size, num_threads= 1, capacity= 32,
  2. enqueue_many= False, shapes= None, dynamic_pad= False,
  3. allow_smaller_final_batch= False, shared_name= None, name= None)
  • 第一个参数tensors:tensor序列或tensor字典,可以是含有单个样本的序列;
  • 第二个参数batch_size: 生成的batch的大小;
  • 第三个参数num_threads:执行tensor入队操作的线程数量,可以设置使用多个线程同时并行执行,提高运行效率,但也不是数量越多越好;
  • 第四个参数capacity: 定义生成的tensor序列的最大容量;
  • 第五个参数enqueue_many: 定义第一个传入参数tensors是多个tensor组成的序列,还是单个tensor;
  • 第六个参数shapes: 可选参数,默认是推测出的传入的tensor的形状;
  • 第七个参数dynamic_pad: 定义是否允许输入的tensors具有不同的形状,设置为True,会把输入的具有不同形状的tensor归一化到相同的形状;
  • 第八个参数allow_smaller_final_batch: 设置为True,表示在tensor队列中剩下的tensor数量不够一个batch_size的情况下,允许最后一个batch的数量少于batch_size, 设置为False,则不管什么情况下,生成的batch都拥有batch_size个样本;
  • 第九个参数shared_name: 可选参数,设置生成的tensor序列在不同的Session中的共享名称;
  • 第十个参数name: 操作的名称;

一般的话只需要定义前三个即:

tf.train.batch([example, label], batch_size=batch_size, capacity=capacity)
  
  

tf.train.shuffle_batch:(和tf.train.batch差不多,最大的差别就是打乱输出一个batch,而不是依次取出一个batch)

tf.train.shuffle_batch([example, label], batch_size=batch_size, capacity=capacity, min_after_dequeue,num_threads=threads)
  
  

里面的相同的参数意思也同上。关于min_after_dequeue是出队后,队列至少剩下min_after_dequeue个数据,但其往往是用来定义混乱级别的,即在定义了随机取样的缓冲区大小的时候,min_after_dequeue越大表示更大级别的混合但是会导致启动更加缓慢,并且会占用更多的内存,同时一定要保证这参数小于capacity参数的值,否则会出错。

一般的话:capacity=(min_after_dequeue+(num_threads+a small safety margin∗batchsize)

num_threads指的是线程数

关于tf.train.shuffle_batch更多的可以看官方文档或者这篇https://blog.csdn.net/u013555719/article/details/77679964

以上按batch读取的话,最后会自动在前面添加一个维度,比如数据的维度是[100],batch_size是10,那么读取出来的shape就是[10,100]

除此之外使用tf.train.string_input_producer创建文件名队列后,其实系统其实还是处于“停滞状态”的,只有tf.train.start_queue_runners之后,才会启动填充队列的线程即


  
  
  1. # 创建一个协调器,管理线程
  2. coord = tf.train.Coordinator()
  3. # 启动QueueRunner, 此时文件名队列已经进队
  4. threads = tf.train.start_queue_runners(sess=sess, coord=coord)

程序进行完后记关闭线程:


  
  
  1. # 通知其他线程关闭
  2. coord.request_stop()
  3. # 其他所有线程关闭之后,这一函数才能返回
  4. coord. join(threads)

所以最后总结一下步骤:

一:定义 gen_tfrecord函数,是源数据转化为tfrecord文件(test and train)

二:定义 read_and_decode函数,读取tfrecord,获得数据(类如image)及其label

三:通过tf.train.batch或者tf.train.shuffle_batch将数据进行分批次的打包(batch为一个批次)

四:定义epoch数,即要利用源数据多少轮,代码中就是最外面的一个for

五:在每一个epoch下,将一个个batch feed给神经网络进行训练

注意:在训练前启动线程tf.train.start_queue_runners,不要程序一直会停留在“停滞状态”

为了更直观的了解上面过程,下面举个简单类子来说明一下:

注:类子主要来自于https://www.bilibili.com/video/av20542427/?p=1的up主。

首先准备一个数据集(验证码),

首先要生成验证码图片,在datasets目录下有gen_image.py用于生成验证码图片。这里可以通过下载或者爬虫获取各种数据集,笔者采用下面方法

需要安装captcha(这是一个生成验证码图片的库)

 pip install captcha
  
  

如果报错no module named setuptools可以参考

https://www.cnblogs.com/Mr-Rice/p/3960487.html

然后运行产生图片的脚本(gen_image.bat)


  
  
  1. python C:/Users/asus-/Desktop/captcha_demo/datasets/gen_image.py ^
  2. --output_dir C:/Users/asus-/Desktop/captcha_demo/datasets/images/ ^
  3. --Captcha_size 4 ^
  4. --image_num 1000 ^
  5. pause

--output_dir就是输出图片的存储路径

--Captcha_size就是识别码图片上面字符的个数

--image_num就是产生图片的数量,但是有可能少于这个数,因为有可能产生重复的随机数,会覆盖前面的

关于gen_image.py为:


  
  
  1. import tensorflow as tf
  2. from captcha.image import ImageCaptcha
  3. import random
  4. import sys
  5. FLAGS = tf.app.flags.FLAGS
  6. tf.app.flags.DEFINE_string( 'output_dir', '/ ', 'This is the saved directory of the picture')
  7. tf.app.flags.DEFINE_integer( 'Captcha_size', 3, 'This is the number of characters of captcha')
  8. tf.app.flags.DEFINE_integer( 'image_num', 1000, 'This is the number of pictures generated ,but less than image_num')
  9. #验证码内容
  10. Captcha_content = [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
  11. # 生成字符
  12. def random_captcha_text():
  13. captcha_text = []
  14. for i in range(FLAGS.Captcha_size):
  15. ch = random.choice(Captcha_content)
  16. captcha_text.append(ch)
  17. return captcha_text
  18. # 生成字符对应的验证码
  19. def gen_captcha_text_and_image():
  20. image = ImageCaptcha()
  21. captcha_text = random_captcha_text()
  22. captcha_text = ''.join(captcha_text)
  23. captcha = image.generate(captcha_text)
  24. image.write(captcha_text, FLAGS.output_dir + captcha_text + '.jpg')
  25. def main(unuse_args):
  26. for i in range(FLAGS.image_num ):
  27. gen_captcha_text_and_image()
  28. sys.stdout.write( '\r>> Creating image %d/%d' % (i+ 1, FLAGS.image_num))
  29. sys.stdout.flush()
  30. sys.stdout.write( '\n')
  31. sys.stdout.flush()
  32. print( "Finish!!!!!!!!!!!")
  33. if __name__ == '__main__':
  34. tf.app.run()

运行后为:

转化图片为tfrecord格式

同样这里写了一个简单的脚本:


  
  
  1. python C:/Users/asus-/Desktop/captcha_demo/datasets/gen_tfrecord.py ^
  2. --dataset_dir C:/Users/asus-/Desktop/captcha_demo/datasets/images/ ^
  3. --output_dir C:/Users/asus-/Desktop/captcha_demo/datasets/ ^
  4. --test_num 10 ^
  5. --random_seed 0 ^
  6. pause

从上到下依次是数据集位置,tfrecord生成位置,测试集个数,随机种子(用于打乱数据集)


  
  
  1. import tensorflow as tf
  2. import os
  3. import random
  4. import math
  5. import sys
  6. from PIL import Image
  7. import numpy as np
  8. FLAGS = tf.app.flags.FLAGS
  9. tf.app.flags.DEFINE_string( 'dataset_dir', '/ ', 'This is the source directory of the picture')
  10. tf.app.flags.DEFINE_string( 'output_dir', '/ ', 'This is the saved directory of the picture')
  11. tf.app.flags.DEFINE_integer( 'test_num', 20, 'This is the number of test of captcha')
  12. tf.app.flags.DEFINE_integer( 'random_seed', 0, 'This is the random_seed')
  13. #判断tfrecord文件是否存在
  14. def dataset_exists(dataset_dir):
  15. for split_name in [ 'train', 'test']:
  16. output_filename = os.path.join(dataset_dir,split_name + '.tfrecords')
  17. if not tf.gfile.Exists(output_filename):
  18. return False
  19. return True
  20. #获取所有验证码图片
  21. def get_filenames_and_classes(dataset_dir):
  22. photo_filenames = []
  23. for filename in os.listdir(dataset_dir):
  24. #获取文件路径
  25. path = os.path.join(dataset_dir, filename)
  26. photo_filenames.append(path)
  27. return photo_filenames
  28. def int64_feature(values):
  29. if not isinstance(values, (tuple, list)):
  30. values = [values]
  31. return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
  32. def bytes_feature(values):
  33. return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))
  34. def image_to_tfexample(image_data, label0, label1, label2, label3):
  35. #Abstract base class for protocol messages.
  36. return tf.train.Example(features=tf.train.Features(feature={
  37. 'image': bytes_feature(image_data),
  38. 'label0': int64_feature(label0),
  39. 'label1': int64_feature(label1),
  40. 'label2': int64_feature(label2),
  41. 'label3': int64_feature(label3),
  42. }))
  43. #把数据转为TFRecord格式
  44. def convert_dataset(split_name, filenames, dataset_dir):
  45. assert split_name in [ 'train', 'test']
  46. with tf.Session() as sess:
  47. #定义tfrecord文件的路径+名字
  48. output_filename = os.path.join(FLAGS.output_dir,split_name + '.tfrecords')
  49. with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
  50. for i,filename in enumerate(filenames):
  51. try:
  52. sys.stdout.write( '\r>> Converting image %d/%d' % (i+ 1, len(filenames)))
  53. sys.stdout.flush()
  54. #读取图片
  55. image_data = Image.open(filename)
  56. #根据模型的结构resize
  57. image_data = image_data.resize(( 224, 224))
  58. #灰度化
  59. image_data = np.array(image_data.convert( 'L'))
  60. #将图片转化为bytes
  61. image_data = image_data.tobytes()
  62. #获取label
  63. labels = filename.split( '/')[ -1][ 0: 4]
  64. num_labels = []
  65. for j in range( 4):
  66. num_labels.append(int(labels[j]))
  67. #生成protocol数据类型
  68. example = image_to_tfexample(image_data, num_labels[ 0], num_labels[ 1], num_labels[ 2], num_labels[ 3])
  69. tfrecord_writer.write(example.SerializeToString())
  70. except IOError as e:
  71. print( 'Could not read:',filename)
  72. print( 'Error:',e)
  73. sys.stdout.write( '\n')
  74. sys.stdout.flush()
  75. def main(unuse_args):
  76. if dataset_exists(FLAGS.output_dir):
  77. print( 'tfcecord file has been existed!!')
  78. else:
  79. #获得所有图片
  80. photo_filenames = get_filenames_and_classes(FLAGS.dataset_dir)
  81. #把数据切分为训练集和测试集,并打乱
  82. random.seed(FLAGS.random_seed)
  83. random.shuffle(photo_filenames)
  84. training_filenames = photo_filenames[FLAGS.test_num:]
  85. testing_filenames = photo_filenames[:FLAGS.test_num]
  86. #数据转换
  87. convert_dataset( 'train', training_filenames,FLAGS.dataset_dir)
  88. convert_dataset( 'test', testing_filenames, FLAGS.dataset_dir)
  89. print( 'Finish!!!!!!!!!!!!!!!!!')
  90. if __name__ == '__main__':
  91. tf.app.run()

运行后:

下面是读取tfrecord格式

我们读取的是test测试集,这里test中有十个样本(batch=1,即每次取一个样本)


  
  
  1. import tensorflow as tf
  2. import image_reader as ir
  3. BATCH_SIZE= 1
  4. image, label0, label1, label2, label3 = ir.read_and_decode( 'C:/Users/asus-/Desktop/captcha_demo/datasets/test.tfrecords')
  5. #使用shuffle_batch可以随机打乱
  6. image_batch, label_batch0, label_batch1, label_batch2, label_batch3 = tf.train.shuffle_batch(
  7. [image, label0, label1, label2, label3], batch_size =BATCH_SIZE,
  8. capacity = 50000, min_after_dequeue= 10000, num_threads= 1)
  9. with tf.Session() as sess:
  10. sess.run(tf.global_variables_initializer())
  11. coord = tf.train.Coordinator()
  12. threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  13. for i in range( 30):
  14. b_image, b_label0, b_label1 ,b_label2 ,b_label3 = sess.run([image_batch, label_batch0, label_batch1, label_batch2, label_batch3])
  15. print( 'label:',b_label0, b_label1 ,b_label2 ,b_label3)
  16. coord.request_stop()
  17. coord.join(threads)

其中image_reader:


  
  
  1. import tensorflow as tf
  2. # 从tfrecord读出数据
  3. def read_and_decode(filename):
  4. # 根据文件名生成一个队列
  5. filename_queue = tf.train.string_input_producer([filename])
  6. reader = tf.TFRecordReader()
  7. # 返回文件名和文件
  8. _, serialized_example = reader.read(filename_queue)
  9. features = tf.parse_single_example(serialized_example,
  10. features={
  11. 'image' : tf.FixedLenFeature([], tf.string),
  12. 'label0': tf.FixedLenFeature([], tf.int64),
  13. 'label1': tf.FixedLenFeature([], tf.int64),
  14. 'label2': tf.FixedLenFeature([], tf.int64),
  15. 'label3': tf.FixedLenFeature([], tf.int64),
  16. })
  17. # 获取图片数据
  18. image = tf.decode_raw(features[ 'image'], tf.uint8)
  19. # tf.train.shuffle_batch必须确定shape
  20. image = tf.reshape(image, [ 224, 224])
  21. # 图片预处理
  22. image = tf.cast(image, tf.float32) / 255.0
  23. image = tf.subtract(image, 0.5)
  24. image = tf.multiply(image, 2.0)
  25. # 获取label
  26. label0 = tf.cast(features[ 'label0'], tf.int32)
  27. label1 = tf.cast(features[ 'label1'], tf.int32)
  28. label2 = tf.cast(features[ 'label2'], tf.int32)
  29. label3 = tf.cast(features[ 'label3'], tf.int32)
  30. return image, label0, label1, label2, label3

第一次我们使用tf.train.shuffle_batch运行后:

可以看到是乱序的,并不是依次取出

接下来我们用tf.train.batch来看一下结果:

可以看到,每十次循环一次,而且顺序不变!!

当实际项目中需要feed给神经网络时,在外面再加一个for循环用于epoch数即可!!!

如果想看类子的全部过程请看:

https://blog.csdn.net/weixin_42001089/article/details/81136954

参考:

https://blog.csdn.net/happyhorizion/article/details/77894055

https://blog.csdn.net/ying86615791/article/details/73864381

https://www.sohu.com/a/148331531_697750

前言:

猜你喜欢

转载自blog.csdn.net/weixin_37799689/article/details/106492006
今日推荐