MOOC笔记——用Tensorflow制作数据集

程序来源:人工智能实践:Tensorflow笔记

程序介绍:

    程序分为以下4个模块来执行数据集的制作以及提取,详细介绍见程序注释

    1.    write_tfRecord()        用于生成tfRecord文件

    2.    generate_tfRecord() 用于把生成的tfRecord文件保存到本地

    3.    read_tfRecord()        用于解析tfRecord文件

    4.    get_tfrecord()           用于批获取训练集或测试集的内容和标签

制作的图片展示(部分)和图片下载链接(待上传):


程序:

#coding:utf-8
import tensorflow as tf
import numpy as np
from PIL import Image
import os

# 这是设置的路径,可以根据您的需要修改
image_train_path='./mnist_data_jpg/mnist_train_jpg_60000/'
label_train_path='./mnist_data_jpg/mnist_train_jpg_60000.txt'
tfRecord_train='./data/mnist_train.tfrecords'
image_test_path='./mnist_data_jpg/mnist_test_jpg_10000/'
label_test_path='./mnist_data_jpg/mnist_test_jpg_10000.txt'
tfRecord_test='./data/mnist_test.tfrecords'
data_path='./data'
# 设置长宽像素点个数
resize_height = 28
resize_width = 28

# 生成tfrecords文件
def write_tfRecord(tfRecordName, image_path, label_path):
    writer = tf.python_io.TFRecordWriter(tfRecordName)  # 新建一个writer
    num_pic = 0 
    f = open(label_path, 'r')
    contents = f.readlines()    # 一次全部读入,速度比较快
    f.close()   
    for content in contents:
        '''
        该目录下的文件下的txt内容为:
        0_5.jpg 5
        1_0.jpg 0
        2_4.jpg 4
        .......
        '''
        value = content.split() # 用空格分开
        img_path = image_path + value[0] 
        img = Image.open(img_path)
        img_raw = img.tobytes() # 转化为二进制文件
        labels = [0] * 10  
        labels[int(value[1])] = 1 # 设置标签位为1 
        
        # 用tf.train.Example的协议存储训练数据,训练数据的特征用键值对的形式表示
        example = tf.train.Example(features=tf.train.Features(feature={
                'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
                'label': tf.train.Feature(int64_list=tf.train.Int64List(value=labels))
                }))     # 把每张图片和标签封装到example中
        writer.write(example.SerializeToString())   # 将example序列化(把数据序列化成字符串存储)
        num_pic += 1 
        print ("the number of picture:", num_pic)
    writer.close() # 关闭writer
    print("write tfrecord successful")

# 产生数据集
def generate_tfRecord():
	isExists = os.path.exists(data_path)  # 判断路径是否存在
	if not isExists: # 如果不存在
		os.makedirs(data_path)   # 新建一个目录
		print ('The directory was created successfully')
	else:
		print ('directory already exists')
    # 生成tfRecords文件
	write_tfRecord(tfRecord_train, image_train_path, label_train_path)
	write_tfRecord(tfRecord_test, image_test_path, label_test_path)

# 解析tfrecords文件 
def read_tfRecord(tfRecord_path):
    # [tfRecord_path]为文件的路径,如果文件比较大可以写多个
    filename_queue = tf.train.string_input_producer([tfRecord_path], shuffle=True)
    reader = tf.TFRecordReader() # 新建一个reader
    _, serialized_example = reader.read(filename_queue) # 将读出的每个样本保存在serialize_example中
    features = tf.parse_single_example(serialized_example, 
                                       features={
                                        'label': tf.FixedLenFeature([10], tf.int64), # 10分类写10
                                        'img_raw': tf.FixedLenFeature([], tf.string)
                                        })  # 解序列化
    img = tf.decode_raw(features['img_raw'], tf.uint8)  # 恢复img_raw 到 img
    img.set_shape([784])  # 把img的shape设为[1,784]
    img = tf.cast(img, tf.float32) * (1. / 255) # 归一化到0-1
    label = tf.cast(features['label'], tf.float32)  # 同时把label值也设为浮点型
    return img, label 

# 批获取训练集或测试集的内容和标签
def get_tfrecord(num, isTrain=True):
    if isTrain: # 获取训练集,isTrain参数设置为True
        tfRecord_path = tfRecord_train
    else:       # 获取测试集,isTrain参数设置为False
        tfRecord_path = tfRecord_test
    img, label = read_tfRecord(tfRecord_path)
    # 从总样本中顺序获取capactiy组数据,打乱顺序,每次输出batch_size组,用了2个线程
    img_batch, label_batch = tf.train.shuffle_batch([img, label],
                                                    batch_size = num,
                                                    num_threads = 2,
                                                    capacity = 1000,
                                                    min_after_dequeue = 700)
    return img_batch, label_batch

def main():
    generate_tfRecord()

if __name__ == '__main__':
    main()

通过运行该程序中的generate_tfRecord()模块,我们就可以在./data/路径下看到以下文件:


猜你喜欢

转载自blog.csdn.net/Invokar/article/details/80050519