阿里云04:TensorFlow 02 数据读取与神经网络

上文,本文阿里云人工智能学习路线中Tensorflow快速入门子课程的后两章内容的笔记,原课程是2018年录制的,代码是TensorFlow1.x写的,很多方法在2.x版本已经弃用,虽然使用上文中提到的方法能够运行1.x版本的语法,但不是长久之计。因此,本文将代码移植到TensorFlow2.1环境下,并将弃用的方法替换。同时,将课程中的内容做了扩展。

3.数据读取

3.1文件读取流程

多线程+队列

  1. QueueRunner:基于队列的输入管道从TensorFlow图形开头的文件读取数据。
  2. Feeding:每运行一步时,python代码提供数据。
  3. 预加载数据:TensorFlow图中的张量包含所有的数据(对于小数据集)。

3.1.1 通用文件读取流程

  • 第一阶段:构造文件名队列
  • 第二阶段:读取与解码
  • 第三阶段:批处理并手动开启线程
    注意:这些操作需要启动运行这些队列操作的线程,以便我们在进行文件读取的过程中能够顺利进行入队出队操作。数据读取流程图:
    在这里插入图片描述

1. 构造文件名队列
将需要读取的文件的文件名放入文件名队列。
tf.train.string_input_producer(string_tensor,shuffle=True) 已弃用!

  • string_tensor:含有文件名+路径的1阶张量;
  • num_epochs:过几遍数据
  • return 文件队列

替换为:

tf.data.Dataset.from_tensor_slices(string_tensor).shuffle(tf.shape(input_tensor, out_type=tf.int64)[0]).repeat(num_epochs)

2. 读取与解码

  • 1)读取文件内容

    • tf.TextLineReader():读取问问文件逗号分隔值(CSV)格式,默认按行读取
      return:读取器实例
    • tf.WholeFileReader():读取图片
      return:读取器实例
    • tf.FixedLengthRecordReaders(record_bytes):二进制文件
      要读取每个记录是固定数量字节的二进制文件
      record_bytes:整型,指定每次读取(一个样本)的字节数
      return:读取器实例
    • tf.TFRecordReader():读取TFRecords文件
      return:读取器实例
      它们有共同的读取方法:key,value = read(file_queue),并且都会返回一个Tensor元组。
    • key:文件名
    • value:一个样本
      由于默认只会读取一个样本,所以如果想要进行批处理,需要使用tf.train.batchtf.train.shuffle_batch进行批处理操作,便于以后制定每批次多个样本的训练。
  • 2)内容解码
    读取不同类型的文件,也应该对读取到的不同类型的内容进行相应的解码操作,解码成统一的Tensor格式。解码阶段,默认所有的内容都解码成tf.uint8类型,如果要转换成指定类型,则需要使用tf.cast()进行相应的转换。

    • tf.decode_csv:解码文本文件内容
    • tf.image.decode_jpeg(contents)
      • 将jpeg编的图像解码为uint8张量
      • return:uint8张量,3-D形状[height,width,channels]
    • tf.image.decode_png(contents)
      • 将png编码的图像解码为uint8张量
      • return:张量类型,3-D形状[height,width,channels]
    • tf.decode_raw:解码二进制文件内容
      • tf.FixedLengthRecordReader搭配使用,二进制读取为uint8类型

3. 批处理

解码之后,可以直接获取默认的一个样本内容。但如果想要获取多个样本,需要加入到新的队列进行批处理。

tf.compat.v1.train.batch(tensors, batch_size, num_threads=1, capacity=32, name=None) 已弃用!

  • 读取指定大小(个数)的张量;
  • tensors:可以是包含张量的列表,批处理的内容放到列表中
  • batch_size:从队列中读取的批处理的大小
  • num_threads:进入队列的线程数
  • capacity:整数,队列中元素的最大数量
  • return :tensors

替换为:

tf.data.Dataset.batch(batch_size, drop_remainder=False)
  • batch_size:一个tf.int64标量tf.Tensor,表示此数据集要在单个批次中合并的连续元素的数量。
  • drop_remainder:(可选)一个tf.bool标量tf.Tensor,表示在batch_size元素少于元素的情况下是否应删除最后一批 ;默认行为是不删除较小的批次。

tf.train.shuffle_batch 已弃用!
替换为:

tf.data.Dataset.shuffle(
    buffer_size, seed=None, reshuffle_each_iteration=None
)
  • buffer_size:一个tf.int64标量tf.Tensor,表示此数据集中要从中采样新数据集的元素数。
  • seed:(可选)tf.int64标量tf.Tensor,表示将用于创建分布的随机种子。请参阅 tf.compat.v1.set_random_seed。
  • reshuffle_each_iteration:(可选)布尔值,如果为true,则表示每次迭代数据集时都应进行伪随机重排。(默认为True)

3.1.2 线程操作

tf.compat.v1.train.QueueRunner(
    queue=None, enqueue_ops=None, close_op=None, cancel_op=None,
    queue_closed_exception_types=None, queue_runner_def=None, import_scope=None
)

队列是一种方便的TensorFlow机制,可使用多个线程异步计算张量。例如,在规范的“输入读取器(Input Reader)”设置中,一组线程在队列中生成文件名。第二组线程从文件中读取记录,对其进行处理,并将张量排入第二个队列;第三组线程使这些输入记录出队以构造批次并通过训练操作运行它们。

每个QueueRunner都负责一个阶段,tf.train.start_queue_runners 函数会要求图中的每个QueueRunner启动他的运行队列操作的线程。(这些操作需要在会话中开启)

tf.compat.v1.train.start_queue_runners(
    sess=None, coord=None, daemon=True, start=True,
    collection=tf.GraphKeys.QUEUE_RUNNERS
)
  • sess:Session用于运行队列操作。默认为默认会话。
  • coord:可选,Coordinator用于协调启动的线程。
  • daemon:线程是否应标记为daemons,表示它们不阻止程序退出。
  • start:设置为False仅创建线程,而不启动它们。
  • collection:一个GraphKey指定图形集合以从中获取队列运行器。默认为GraphKeys.QUEUE_RUNNERS。

3.2 图片数据

3.2.1 图像基本知识

  • 特征提取
    文本 -> 数值(二维数组shape(n_samples,m_faetures))
    字典 -> 数值(二维数组shape(n_samples,m_faetures))
    图片 ->
  1. 图片三要素
    组成一张图片特征值是所有的像素值,图片有三个维度长、宽、通道(channel)数
  • 1)灰度图[长,宽,1]
    每一个像素点为[0,255]的数值,越接近于0,图片越黑。
  • 2)彩色图[长,宽,3]
    每个像素点用3个[0,255]的数值描述。
  1. 张量的形状
    Tensor(指令名称,shape,dtype)
    一张图片 shape = (height, width, channerls)
    多张图片 shape = (batch, height, width, channels)

3.2.2 图片特征值处理

为什么要缩放图片到统一大小?
在进行图像识别的时候,每个图片样本的特征数量要保持相同。所以需要将所有图片张量大小统一装换。另一方面,如果图片的像素量太大,通过这种方法适当减少像素的数量,减少训练的计算开销。

tf.image.resize_images(images,size) 已弃用!
替换为:

tf.image.resize(
    images, size, method=ResizeMethod.BILINEAR, preserve_aspect_ratio=False,
    antialias=False, name=None
)
  • images:形状的4-D张量[batch, height, width, channels]或形状的3-D张量[height, width, channels]。
  • size:2个元素的一维int32张量:new_height, new_width。图片的新尺寸。
  • method:ResizeMethod。默认为bilinear。
  • preserve_aspect_ratio:是否保留长宽比。如果设置了此项,则将images在size保留原始图像的纵横比的同时将其调整为适合的尺寸。如果图像size大于当前尺寸,则按比例放大图像 image。默认为False。
  • antialias:对图像进行下采样时是否使用抗混叠滤波器。
  • name:此操作的名称(可选)。

3.2.3 数据格式

  • 存储:uint(节约空间)
  • 矩阵运算:float32(提高精度)

编程过程中,要注意转换!

3.2.4 案例:图片读取

  • 第一阶段:构造文件名队列
  • 第二阶段:读取与解码
  • 第三阶段:批处理并手动开启线程

TF2中运行TF1的代码:

import tensorflow as tf 
tf.compat.v1.disable_eager_execution()
import os
def picture_read():
    """
    图片读取案例
    """
    filename = os.listdir("D:/AliyunEDU/cats vs dogs")
    # 拼接路径+文件名
    file_list = [os.path.join("D:/AliyunEDU/cats vs dogs/",file) for file in filename]
    #print("file_list:{}".format(fiel_list))
    
    # 1.构造文件名队列
    file_queue = tf.compat.v1.train.string_input_producer(file_list)
    #tf.data.Dataset.from_tensor_slices(string_tensor).shuffle(tf.shape(input_tensor, out_type=tf.int64)[0]).repeat(num_epochs)

    # 2.读取与解码
    ## 读取阶段 uint8
    reader = tf.compat.v1.WholeFileReader()
    # key为文件名,value为一张图片原始编码形式
    key,value = reader.read(file_queue)
    #print(" key:{} \n value: {} \n" .format(key,value))
    
    ## 解码阶段 变为float32
    image = tf.compat.v1.image.decode_jpeg(value)
    #print("image:\n",image)
    
    # 图像的形状,类型修改
    image_resized = tf.compat.v1.image.resize_images(image,[200,200])
    
    ## 静态形状自改
    image_resized.set_shape(shape=[200,200,3])
    print("image_resized_new:\n",image_resized)
    
    # 3.批处理
    image_batch = tf.compat.v1.train.batch([image_resized], batch_size=100, num_threads=1, capacity=100)
    
    # 开启会话
    with tf.compat.v1.Session() as sess:
        # 开启线程
        ## ①线程协调员
        coord = tf.compat.v1.train.Coordinator()
        threads = tf.compat.v1.train.start_queue_runners(sess=sess,coord=coord)
        
        key_new,value_new,image_new,image_resized_new,image_batch_new = sess.run([key, value, image, image_resized, image_batch])
        #print(" key_new:{}\n value_new:{}\n".format(key_new,value_new))
        
        # 查看解码后的图片的值
        print(" image_new:\n",image_new)
        # 查看resize之后的数据
        print("image_resized_new:\n",image_resized_new)
        # 查看批处理之后的数据
        print("image_batch_new:\n",image_batch_new)
        
        ## ②回收线程
        coord.request_stop()
        coord.join(threads)
    
picture_read()    

输出:

 image_new:
 [[[ 86  67  71]
  [ 84  65  69]
  [ 80  61  65]
  ...
  [ 48  33  38]
  [ 49  34  39]
  [ 50  35  40]]

 [[ 80  61  65]
  [ 79  60  64]
  [ 76  57  61]
  ...
  [215 203 191]
  [214 202 190]
  [213 201 189]]]
image_resized_new:
 [[[ 86.        67.        71.      ]
  [ 78.515     59.515     63.515   ]
  [ 75.979996  59.98      60.98    ]
  ...
  [ 50.880127  36.880127  36.39511 ]
  [ 56.900208  42.900208  42.900208]
  [ 48.504974  33.504974  38.504974]]

 [[ 78.25      59.25      63.25    ]
  [ 75.44312   56.443127  60.443127]
  [ 71.2575    55.2575    56.2575  ]
  ...
 
  [213.58685  201.58685  189.58685 ]
  [217.25998  205.25998  193.25998 ]
  [214.12003  202.12003  190.12003 ]]

 [[ 60.5       40.5       51.5     ]
  [ 51.997498  32.4925    40.0075  ]
  [ 47.0275    31.0275    31.0575  ]
  ...
  [215.       203.       191.      ]
  [214.74002  202.74002  190.74002 ]
  [213.       201.       189.      ]]

 [[ 59.125     39.125     50.125   ]
  [ 55.22875   35.723747  43.238747]
  [ 56.1475    40.1475    40.1775  ]
  ...
  [212.75     200.75     188.75    ]
  [213.00125  201.00125  189.00125 ]
  [213.62003  201.62003  189.62003 ]]]
image_batch_new:
 [[[[140.       134.       108.      ]
   [140.       134.       109.45    ]
   [140.       134.       110.      ]
   ...
   [113.        85.        61.      ]
   [113.        85.        61.      ]
   [113.        85.        61.      ]]

  [[140.       134.       109.46    ]
   [140.       134.       109.8515  ]
   [140.       134.       110.      ]
   ...
   [113.        85.        61.      ]
   [113.        85.        61.      ]
   [113.        85.        61.      ]]

  [[140.       133.54     112.92    ]
   [140.       133.6665   112.369995]
   [140.       134.       110.92    ]
   ...
   [113.        85.        61.      ]
   [113.        85.        61.      ]
   [113.        85.        61.      ]]

  [[186.       180.       158.      ]
   [190.9      184.9      162.9     ]
   [200.7      194.7      172.7     ]
   ...
   [136.47504  118.650024  97.44549 ]
   [137.55002  117.349945  93.49844 ]
   [137.275    115.45001   88.17502 ]]]
   
   ...
   
 [[[116.        88.        76.      ]
   [122.99      94.99      82.99    ]
   [124.        94.        83.      ]
   ...
   [121.514984  94.485016  80.      ]
   [121.99002   94.00998   80.01996 ]
   [117.99005   92.504974  82.00995 ]]
   ...
  [[137.65002   99.650024  88.650024]
   [142.16064  104.16064   91.16064 ]
   [141.76129  103.76129   90.76129 ]
   ...
   [109.90308  117.90308  130.90308 ]
   [104.       112.       125.      ]
   [102.495026 108.495026 122.495026]]]]

.jpg格式图片数据说明

[[[227 184 178]
  [179 136 130]
  [183 140 134]
  ...
  [179 169 160]
  [175 165 156]
  [172 162 153]]

 [[177 134 128]
  [123  80  74]
  [118  75  69]
  ...
  [124 114 105]
  [120 110 101]
  [117 107  98]]
  ...

在这里插入图片描述

3.3 二进制文件读取

3.3.1 CIFAR-10数据集说明

二进制版本数据集的格式:<1×标签> <3072×像素>
第一个字节是第一个图像的标签,它是一个0-9范围内的数字。接下来的3072个字节是图像像素的值。前1024个字节是红色通道值,下1024个绿色,最后1024个蓝色。值以行优先顺序存储,因此前32个字节是图像第一行的红色通道值。

3.3.2 二进制数据读取

流程分析:

  1. 构造文件名队列
  2. 读取与解码
    reader = tf.compat.v1.FixedLengthRecordReader(3073)
    key,value = reader.read(file_queue)
    decoded = tf.decode_raw(value, tf.uint8)
  • 对tensor对象进行切片以截取标签和图片
  • 改变图像的形状(tensorflow图像的表示习惯收为通道数、长、宽)
  • 转置将图片的顺序调整为height、width、channels(reshape之后涉及到NHWCNCHW转换的问题)ndarray.T 转置 行变列,列变行
  1. 批处理
"""
本文代码适用二进制格式的CIFAR-10数据
"""
CIFAR-10数据下载:
http://www.cs.toronto.edu/~kriz/cifar.html
import tensorflow as tf
tf.compat.v1.disable_eager_execution()

class Cifar(object):
    def __init__(self):
        # 初始化操作
        self.height = 32
        self.width = 32
        self.channels = 3
        
        # 字节数
        self.image_bytes = self.height * self.width * self.channels
        self.label_bytes = 1
        self.all_bytes = self.label_bytes + self.image_bytes
        
    def read_and_decode(self,file_list):
        # 1.构造文件名队列
        file_queue = tf.compat.v1.train.string_input_producer(file_list)
        
        # 2.读取与解码
        ## 读取阶段
        reader = tf.compat.v1.FixedLengthRecordReader(self.all_bytes)
        key,value = reader.read(file_queue)
        print("key:\n{}\n value:\n{}\n".format(key,value))
        
        ## 解码阶段
        decoded = tf.compat.v1.decode_raw(value, tf.uint8)
        
        ## 1_将目标值和特征值切片切开
        label = tf.slice(decoded, [0], [self.label_bytes])
        image = tf.slice(decoded, [1], [self.image_bytes])
        
        ## 2_调整图片形状,以符合tensor的输入要求
        image_reshaped = tf.reshape(image, shape=[self.channels,self.height,self.width])
        
        ## 3_转置 将图片调整为 HWC
        image_transposed = tf.transpose(image_reshaped,[1,2,0])
        print("image_reshaped:{}\n image_tansposed:{}\n".format(image_reshaped,image_transposed))
        
        ## 4_调整图像类型 uint8->float32
        image_cast = tf.cast(image_transposed, tf.float32)
        
        # 3.批处理
        label_batch,image_batch = tf.compat.v1.train.batch([label,image_cast],batch_size=100,num_threads=1,capacity=100)
        print(label_batch)
        
        # 开启会话
        with tf.compat.v1.Session() as sess:
            # 开启线程
            coord = tf.compat.v1.train.Coordinator()
            threads = tf.compat.v1.train.start_queue_runners(sess=sess, coord=coord)
            
            key_new, value_new, decoded_new, label_new, image_new,image_reshaped_new,image_transposed_new = sess.run(
                [key, value, decoded, label, image, image_reshaped, image_transposed]
            )
            label_value, image_value = sess.run([label_batch, image_batch])
            print("decoded_new:\n",decoded_new)
            
            # 回收线程
            coord.request_stop()
            coord.join(threads)
            
        return None
        
file_name = os.listdir("D:/Project/Data/cifar-10-bin")
# 构建文件名路径列表
file_list = [os.path.join("D:/Project/Data/cifar-10-bin/", file) for file in file_name if file[-3:] == "bin"]

# 实例化Cifar
cifar = Cifar()
cifar.read_and_decode(file_list)

输出

key:
Tensor("ReaderReadV2_1:0", shape=(), dtype=string)
 value:
Tensor("ReaderReadV2_1:1", shape=(), dtype=string)

image_reshaped:Tensor("Reshape:0", shape=(3, 32, 32), dtype=uint8)
 image_tansposed:Tensor("transpose:0", shape=(32, 32, 3), dtype=uint8)

Tensor("batch_1:0", shape=(100, 1), dtype=uint8)
decoded_new:
 [  8  98  91 ... 125 132 138]

tf.slice 方法说明

tf.slice( input_, begin, size, name=None )
  • input_: A Tensor.
  • begin: An int32 or int64 Tensor.
  • size: An int32 or int64 Tensor.
  • name: A name for the operation (optional).

3.4 TFRecords文件

3.4.1 什么是TFRecords文件?

是一种二进制文件,虽然它不如其他格式的数据好理解,但是它能更好的利用内容,并且不需要单独的标签,即样本和样本标签是绑定在一起的。

使用步骤:

  • 1)获取数据
  • 2)将数据填入到example协议内存块(protocol buffer)
  • 3)将协议内存块序列化为字符串,并且通过tf.python_io.TFRecordWriter写入到TFRecord文件中。
  • 4)开启会话
  • 5)手动开启线程

3.4.2 Example 结构解析

cifar10

  • 特征值 - image - 3072个字节
  • 目标值 - label - 1个字节
  example = tf.compat.v1.train.Example(features=tf.compat.v1.train.Features(feature={
                    "image":tf.compat.v1.train.Feature(bytes_list=tf.compat.v1.train.BytesList(value=[image])),
                    "label":tf.compat.v1.train.Feature(int64_list=tf.compat.v1.train.Int64List(value=[label])),
                }))

example结构:

features {
  feature {
    key: "feature0"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "feature1"
    value {
      int64_list {
        value: 4
      }
    }
  }
  feature {
    key: "feature2"
    value {
      bytes_list {
        value: "goat"
      }
    }
  }
  feature {
    key: "feature3"
    value {
      float_list {
        value: 0.9876000285148621
      }
    }
  }
}
- tf.train.Example()
  - 写入tfrecords文件
  - features:tf.train.Features类型的特征实例
- tf.train.Features()
  - 构建每个样本的信息键值对
  - features:字典数据,key为要保存的名字
  - value为tf.train.Feature实例
  - return:Features类型
- tf.train.Feature()
  - options:例如
    - bytes_list = tf.train.BytesList(value=[Bytes])
    - int64_list = tf.train.int64List(value[Value])
  - 支持输入的类型如下:
  - tf.train.int64List(value=[Value])
  - tf.train.BytesList(value=[Bytes])
  - tf.train.FloatList(value=[Value])

3.4.3 CIFAR10数据存入TFRecords文件

  • 构造存储实例,tf.python_io.TFRecordWriter(path)
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
import os


class Cifar(object):
    def __init__(self):
        # 初始化操作
        self.height = 32
        self.width = 32
        self.channels = 3
        
        # 字节数
        self.image_bytes = self.height * self.width * self.channels
        self.label_bytes = 1
        self.all_bytes = self.label_bytes + self.image_bytes
        
    def read_binary(self):
        # 1.构造文件名队列
        file_name = os.listdir("D:/Project/Data/cifar-10-bin")
        # 构建文件名路径列表
        file_list = [os.path.join("D:/Project/Data/cifar-10-bin/", file) for file in file_name if file[-3:] == "bin"]
        file_queue = tf.compat.v1.train.string_input_producer(file_list)
        
        # 2.读取与解码
        ## 读取阶段
        reader = tf.compat.v1.FixedLengthRecordReader(self.all_bytes)
        key,value = reader.read(file_queue)
        print("key:\n{}\n value:\n{}\n".format(key,value))
        
        ## 解码阶段
        decoded = tf.compat.v1.decode_raw(value, tf.uint8)
        
        ## 1_将目标值和特征值切片切开
        label = tf.slice(decoded, [0], [self.label_bytes])
        image = tf.slice(decoded, [1], [self.image_bytes])
        
        ## 2_调整图片形状,以符合tensor的输入要求
        image_reshaped = tf.reshape(image, shape=[self.channels,self.height,self.width])
        
        ## 3_转置 将图片调整为 HWC
        image_transposed = tf.transpose(image_reshaped,[1,2,0])
        print("image_reshaped:{}\n image_tansposed:{}\n".format(image_reshaped,image_transposed))
        
        ## 4_调整图像类型 uint8->float32
        image_cast = tf.cast(image_transposed, tf.float32)
        
        # 3.批处理
        label_batch,image_batch = tf.compat.v1.train.batch([label,image_cast],batch_size=100,num_threads=1,capacity=100)
        print(label_batch)
        
        # 开启会话
        with tf.compat.v1.Session() as sess:
            # 开启线程
            coord = tf.compat.v1.train.Coordinator()
            threads = tf.compat.v1.train.start_queue_runners(sess=sess, coord=coord)
            
            key_new, value_new, decoded_new, label_new, image_new,image_reshaped_new,image_transposed_new = sess.run(
                [key, value, decoded, label, image, image_reshaped, image_transposed]
            )
            label_value, image_value = sess.run([label_batch, image_batch])
            print("decoded_new:\n",decoded_new)
            
            # 回收线程
            coord.request_stop()
            coord.join(threads)
            
        return image_value, label_value
    
    def write_to_tfrecords(self, image_batch, label_batch):
        """
        将样本的特征值写入TFRecords文件
        """
        with tf.compat.v1.python_io.TFRecordWriter("cifar10.tfrecords") as writer:
            # 循环构造example对象,并序列化写入文件
            for i in range(100):
                image = image_batch[i].tostring() #bytes类型
                label = label_batch[i][0] #取出整型单值
                #print("label:{}\n image:{}".format(label,image))
                example = tf.compat.v1.train.Example(features=tf.compat.v1.train.Features(feature={
                    "image":tf.compat.v1.train.Feature(bytes_list=tf.compat.v1.train.BytesList(value=[image])),
                    "label":tf.compat.v1.train.Feature(int64_list=tf.compat.v1.train.Int64List(value=[label])),
                }))
                #example.SerializeToString()
                # 将序列化后的example下入example文件
                writer.write(example.SerializeToString())
                                                                      
        return None

    
    def read_tfrecords(self):
        # 1.构造文件名队列
        file_queue = tf.compat.v1.train.string_input_producer(["cifar10.tfrecords"])
        
        # 2.读取与解码
        reader = tf.compat.v1.TFRecordReader()
        key,value = reader.read(file_queue)
        
        # 解析 example
        feature = tf.compat.v1.parse_single_example(value, features={
            "image":tf.compat.v1.FixedLenFeature([], tf.string),
            "label":tf.compat.v1.FixedLenFeature([], tf.int64)
        })
        image = feature["image"]
        label = feature["label"]
        
        ## 解码
        image_decoded = tf.compat.v1.decode_raw(image, tf.uint8)
        print("image_decoded:",image_decoded)
        
        ## 图像形状调整
        image_reshaped = tf.compat.v1.reshape(image_decoded, [self.height, self.width, self.channels])
        
        # 3.构造批处理队列
        image_batch,label_batch = tf.compat.v1.train.batch([image_reshaped, label], batch_size=100, num_threads=2, capacity=100)
        
        
        with tf.compat.v1.Session() as sess:
            
            # 开启线程
            coord = tf.compat.v1.train.Coordinator()
            threads = tf.compat.v1.train.start_queue_runners(sess=sess, coord=coord)
            
            image_value, label_value = sess.run([image_batch, label_batch])
            print("image_value:\n",image_value)
            
            # 回收线程
            coord.requset_stop()
            coord.join(threads)
            
        return None

# 实例化Cifar
cifar = Cifar()
image_value, label_value = cifar.read_binary()
cifar.write_to_tfrecords(image_value, label_value)

3.4.4 读取TFRecords文件API

  • tf.parse_single_example(serialized, features=None,name=None)

    • 解析一个单一的Example原型。
    • serialized:标量字符串Tensor,一个序列化的Example。
    • features:dict字典数据,键为读取的名字,值为FixedLenFeature。
    • return:一个键值对组成的字典,键为读取的名字。
  • tf.FixedLenFeature(shape,dtype)

    • shape:输入数据的形状,一般不指定,为空列表。
    • dtype:输入数据的类型。类型只能是float32,int64,string。

步骤:

  • 1)构造文件名队列
  • 2)读取和解码
    • 读取
      解析example
  feature = tf.compat.v1.parse_single_example(values, features={
  "image":tf.compat.v1.FixedLenFeature([], tf.string)
  "label":tf.compat.v1.FixedLenFeature([], tf.int64)
  })
  image = feature["image"]
  label = feature["label"]
发布了122 篇原创文章 · 获赞 94 · 访问量 2万+

猜你喜欢

转载自blog.csdn.net/weixin_39653948/article/details/104940205