tensorflow 的dataset读取tfrecord格式数据,转为图片输出

直接上代码:

# -*- coding: utf-8 -*-
import tensorflow as tf

from PIL import Image

#1 create dataset ,创建数据集
input_files=['../tfrecord/traindata-000.tfrecord','../tfrecord/traindata-001.tfrecord']
dataset=tf.data.TFRecordDataset(input_files)

#2 parser dataset,因为是tfrecord所以要定义解析含数来解析
def parser(record):
    features=tf.parse_single_example(
            record,
            features={#和原来生成tfrecord时候要对应相同的
                'label': tf.FixedLenFeature([], tf.int64),
                'img_raw' : tf.FixedLenFeature([], tf.string),
                'img_width': tf.FixedLenFeature([], tf.int64),
                'img_height': tf.FixedLenFeature([], tf.int64),
                }
            )
    return features["label"],features['img_raw'],features['img_width'],features['img_height']
dataset=dataset.map(parser)#接受的参数是一个函数

#3 use iterator get value 定义迭代器,才能获得起张量
iterator=dataset.make_one_shot_iterator()
label,img,width,height=iterator.get_next()
#担心数据不一致,所以转化一次
image = tf.decode_raw(img, tf.uint8)
height = tf.cast(height,tf.int32)
width = tf.cast(width,tf.int32)
label = tf.cast(label, tf.int32)
channel = 3
image = tf.reshape(image, [height,width,channel])

with tf.Session() as sess:
    for i in range(300):
        #test -->yes
        #print(sess.run([label,img,width,height]))
        #print('=======================================')
        #数据集所有的参数已经确定,所以不要参数初始化过程,但是要有值,张量必须run()一下
        single,l = sess.run([image,label])
        pic=Image.fromarray(single, 'RGB')

        pic.save('./tfrecord2pic'+'/'+str(i)+'_''Label_'+str(l)+'.jpg')#存下图片

========提示

红色字体为要修改的地方。

其实和我的上一篇tfrecord格式数据转化为图片,一样。

只是这里用过dataset来读取而已,代码很多一样,只是解析部分不同。

效果如下:


       
  

猜你喜欢

转载自blog.csdn.net/fu6543210/article/details/80269215