TensorFlow01: 二进制文件读取

实现代码:

# Read CIFAR-10 binary records through a TF1 queue-based input pipeline and
# fetch one batch of (label, image) pairs.

# Directory holding the CIFAR-10 binary files. A single variable keeps the
# directory listing and the joined paths in sync (the original literals
# disagreed: "cifar" vs "ficar", so every queued path pointed nowhere).
data_dir = "../data/cifar/"

# Build the file list. Only *.bin files are fixed-length record files; the
# CIFAR-10 binary archive also ships metadata such as batches.meta.txt, which
# would be mis-parsed as image records if fed to the reader.
file_name = os.listdir(data_dir)
file_list = [
    os.path.join(data_dir, name) for name in file_name if name.endswith(".bin")
]
if not file_list:
    raise FileNotFoundError("no .bin files found in " + data_dir)

# Layout of one record: a 1-byte label followed by the image bytes, stored
# channel-major (depth x height x width), as the reshape/transpose below assume.
label_bytes = 1
height, width, depth = 32, 32, 3
image_bytes = height * width * depth
record_bytes = label_bytes + image_bytes

# Build the filename queue.
file_queue = tf.train.string_input_producer(file_list)
# Read one fixed-length record at a time.
reader = tf.FixedLengthRecordReader(record_bytes)
key, value = reader.read(file_queue)
print(value)
# Decode the raw record string into a flat uint8 vector.
decoded = tf.decode_raw(value, tf.uint8)
print(decoded)
# Split the record into the target (label) and the features (image).
label = tf.slice(decoded, [0], [label_bytes])
image = tf.slice(decoded, [label_bytes], [image_bytes])
print("label:", label)
print("image:", image)
# Restore the stored channel-major shape ...
image_reshape = tf.reshape(image, shape=[depth, height, width])
print("image_reshape:", image_reshape)
# ... then move channels last (height, width, depth).
image_transposed = tf.transpose(image_reshape, [1, 2, 0])
print("image_transposed:", image_transposed)
# Convert pixel values to float32 for downstream computation.
image_cast = tf.cast(image_transposed, tf.float32)
# Group single examples into batches.
label_batch, image_batch = tf.train.batch(
    [label, image_cast], batch_size=100, num_threads=1, capacity=100
)
print("label_batch:", label_batch)
print("image_batch:", image_batch)

with tf.Session() as sess1:
    print("----------")
    # Start the queue-runner threads that feed the input queues.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess1, coord=coord)
    print("threads:", threads)
    try:
        a, b = sess1.run([label_batch, image_batch])
        print("label_batch+++++:", a)
        print("image_batch+++++:", b)
        print("999999")
    finally:
        # Always stop and join the feeder threads, even if the run fails,
        # so the process does not hang on live queue runners.
        coord.request_stop()
        coord.join(threads)

运行结果:

Tensor("ReaderReadV2:1", shape=(), dtype=string)
Tensor("DecodeRaw:0", shape=(?,), dtype=uint8)
label: Tensor("Slice:0", shape=(1,), dtype=uint8)
image: Tensor("Slice_1:0", shape=(3072,), dtype=uint8)
image_reshape: Tensor("Reshape:0", shape=(3, 32, 32), dtype=uint8)
image_transposed: Tensor("transpose:0", shape=(32, 32, 3), dtype=uint8)
label_batch: Tensor("batch:0", shape=(100, 1), dtype=uint8)
image_batch: Tensor("batch:1", shape=(100, 32, 32, 3), dtype=float32)

猜你喜欢

转载自www.cnblogs.com/jumpkin1122/p/11522000.html