import os import tensorflow as tf from PIL import Image import numpy as np cat_image_path='D:/软件/pycharmProject/wenyuPy/CatImage/' cat_tfrecords='D:/软件/pycharmProject/wenyuPy/CatImage/cat.tfrecords' writer=tf.python_io.TFRecordWriter(cat_tfrecords) label1=np.array([1,0,0]) label2=np.array([0,1,0]) label3=np.array([0,0,1]) labels=[label1,label2,label3] img1=Image.open('D:/软件/pycharmProject/wenyuPy/CatImage/1.jpg') img1 = img1.resize((256, 256)) img2=Image.open('D:/软件/pycharmProject/wenyuPy/CatImage/2.jpg') img2 = img2.resize((256, 256)) img3=Image.open('D:/软件/pycharmProject/wenyuPy/CatImage/3.jpg') img3 = img3.resize((256, 256)) images=[img1,img2,img3] for i in range(len(images)): features=tf.train.Features(feature={ 'catimage':tf.train.Feature(bytes_list=tf.train.BytesList(value=[images[i].tobytes()])), 'catlabel':tf.train.Feature(bytes_list=tf.train.BytesList(value=[labels[i].tobytes()])) } ) example=tf.train.Example(features=features) writer.write(example.SerializeToString()) writer.close() print('the tfrecords has benn writen')
import tensorflow as tf from PIL import Image input_tfrecords='D:/软件/pycharmProject/wenyuPy/CatImage/cat.tfrecords' #create a dataset cat_dataset=tf.data.TFRecordDataset(input_tfrecords) #定义解析函数来解析我们刚才所生成的tfrecords文件 def parser(record): features=tf.parse_single_example( record, features={ 'catimage':tf.FixedLenFeature([],tf.string), 'catlabel':tf.FixedLenFeature([],tf.string) }) return features['catimage'],features['catlabel'] #dataset中的map接收的是一个函数,dataset中的每个元素都会被当作这个函数的输入并且并将函数的返回值作为新的dataset cat_dataset=cat_dataset.map(parser) cat_iterator=cat_dataset.make_one_shot_iterator() #label=tf.cast(label,tf.int32) channel=3#定义的是RGB图像 with tf.Session() as sess: for i in range(3): img, label = cat_iterator.get_next() image = tf.decode_raw(img, tf.uint8) image = tf.reshape(image, [256, 256, 3]) single,l=sess.run([image,label]) pic=Image.fromarray(single,'RGB') pic.save('D:/软件/pycharmProject/wenyuPy/CatImage/tfrecordscat/'+str(i)+'.jpg') print('the picture has been take out')
我们之前是通过filename_queue=tf.train.string_input_producer([filename],shuffle=True),这条语句将文件名打乱生成一个文件名序列,其实我也不太懂为什么这样做,然后再用reader=tf.TFRecordReader()用来读取文件序列,我们读取到的是已经被序列化的二进制图像和label,然后再对其进行反序列化并且将二进制文件还原成我们原始的图像。但是我在运行的时候发现IDE出现了一个警告说TFRecordReader读取文件序列已经被tf.data.TFRecordDataset取代,然后查资料将程序改动了一下。1.首先我随机找了三张猫的图片放在了我电脑的D:\软件\pycharmProject\wenyuPy\CatImage这个目录下 2.然后我将生成好的猫图像的tfrecords文件放在了同样的目录下,地址可随意指定。 3.我们使用tf.data.TFRecordDataset(filename)来生成一个dataset,然后用这个dataset去生成一个迭代器,每次迭代一组image和label,存到指定的目录下即可。fromarray函数我也没有弄懂是什么意思,还有待解决。