版权声明:希望尊重辛苦的自学之旅 https://blog.csdn.net/sinat_42378539/article/details/83047631
使用tensorflow训练自己的数据集—制作数据集
想记录一下自己制作训练集并训练的过程,希望踩过的坑能帮助后面入坑的人。
本次使用的训练集是kaggle中经典的猫狗大战数据集(提取码:ufz5)。因为本人笔记本配置很差还不是N卡所以把train的数据分成了训练集和测试集并没有使用原数据集中的test。在tensorflow中使用TFRecord格式喂给神经网络但是现在官方推荐使用tf.data
但这个API还没看所以还是使用了TFRecord。
代码注释还挺清楚就直接上代码了。
import os
import tensorflow as tf
from PIL import Image
# Root directory of the source images (one sub-directory per class)
cwd = 'C:/Users/Qigq/Desktop/P_Data/kaggle/train'
# Output paths for the generated TFRecord files
train_record_path = "C:/Users/Qigq/Desktop/P_Data/kaggle/ouputdata/train.tfrecords"
test_record_path = "C:/Users/Qigq/Desktop/P_Data/kaggle/ouputdata/test.tfrecords"
# Class names. NOTE: this was a set literal ({'cat', 'dog'}) originally; set
# iteration order depends on the per-process string hash seed, so the integer
# label assigned to each class could differ between runs — records written in
# one run and read in another could have swapped labels. A list fixes the
# mapping deterministically: 0 -> cat, 1 -> dog.
classes = ['cat', 'dog']
def _byteslist(value):
    """Wrap raw bytes in a tf.train.Feature holding a BytesList."""
    wrapped = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=wrapped)
def _int64list(value):
    """Wrap a single integer in a tf.train.Feature holding an Int64List."""
    wrapped = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=wrapped)
def create_train_record():
    """Write the first 70% of each class's images to the train TFRecord.

    Each image is converted to RGB, resized to 128x128 and serialized as a
    tf.train.Example with two features: 'label' (int64 class index) and
    'img_raw' (raw RGB bytes). The companion read_record() expects exactly
    this layout.
    """
    writer = tf.python_io.TFRecordWriter(train_record_path)  # record writer
    num = 1  # progress counter for the log output
    for index, name in enumerate(classes):
        class_path = cwd + "/" + name + '/'
        # List the directory once so the split index and the iterated slice
        # are guaranteed to agree (the original listed it twice).
        file_names = os.listdir(class_path)
        split = int(len(file_names) * 0.7)  # first 70% -> training set
        for img_name in file_names[:split]:
            img_path = class_path + img_name
            img = Image.open(img_path)
            # Force 3 channels: grayscale/palette/RGBA files would otherwise
            # produce a byte count that read_record's reshape to
            # [128, 128, 3] cannot recover from.
            img = img.convert('RGB')
            img = img.resize((128, 128))  # fixed input size for the network
            img_raw = img.tobytes()  # raw pixel bytes
            example = tf.train.Example(  # wrap features into an Example
                features=tf.train.Features(feature={
                    "label": _int64list(index),    # label must be an integer feature
                    'img_raw': _byteslist(img_raw)  # image must be a bytes feature
                }))
            writer.write(example.SerializeToString())
            print('Creating train record in ', num)
            num += 1
    writer.close()  # flush and close the record file
    print("Create train_record successful!")
def create_test_record():
    """Write the remaining 30% of each class's images to the test TFRecord.

    Uses the same Example layout as create_train_record(): 'label'
    (int64 class index) and 'img_raw' (raw 128x128 RGB bytes).
    """
    writer = tf.python_io.TFRecordWriter(test_record_path)
    num = 1  # progress counter for the log output
    for index, name in enumerate(classes):
        class_path = cwd + '/' + name + '/'
        # Single listdir call keeps the split index and the iterated slice
        # consistent with each other.
        file_names = os.listdir(class_path)
        split = int(len(file_names) * 0.7)
        for img_name in file_names[split:]:  # remaining 30% -> test set
            img_path = class_path + img_name
            img = Image.open(img_path)
            # Force 3 channels so the byte count always matches the
            # [128, 128, 3] reshape in read_record.
            img = img.convert('RGB')
            img = img.resize((128, 128))
            img_raw = img.tobytes()  # raw pixel bytes
            example = tf.train.Example(
                features=tf.train.Features(feature={
                    "label": _int64list(index),
                    'img_raw': _byteslist(img_raw)
                }))
            writer.write(example.SerializeToString())
            print('Creating test record in ', num)
            num += 1
    writer.close()
    print("Create test_record successful!")
def read_record(filename):
    """Read one (image, label) example from a TFRecord file.

    Returns a pair of tensors: a float32 image of shape [128, 128, 3]
    normalized to [-0.5, 0.5), and an int32 class label. Uses the TF 1.x
    queue-based input pipeline, so evaluating these tensors requires a
    session with queue runners started (see the commented usage example
    at the bottom of the file).
    """
    filename_queue = tf.train.string_input_producer([filename])  # file-name queue
    reader = tf.TFRecordReader()  # record reader
    _, serialized_example = reader.read(filename_queue)
    # Parse a single serialized Example back into its two features;
    # the layout must match what create_*_record() wrote.
    features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'img_raw': tf.FixedLenFeature([], tf.string)
        }
    )
    label = features['label']
    img = features['img_raw']
    img = tf.decode_raw(img, tf.uint8)        # bytes -> flat uint8 tensor
    img = tf.reshape(img, [128, 128, 3])      # restore H x W x C layout
    img = tf.cast(img, tf.float32) * (1. / 255) - 0.5  # normalize to [-0.5, 0.5)
    label = tf.cast(label, tf.int32)
    return img, label
def get_batch_record(filename, batch_size):
    """Draw shuffled batches of (image, label) pairs from a TFRecord file."""
    img, lbl = read_record(filename)
    # shuffle_batch keeps an internal buffer (capacity=2000, at least 1000
    # elements after each dequeue) and samples batch_size random examples.
    batch = tf.train.shuffle_batch(
        [img, lbl],
        batch_size=batch_size,
        capacity=2000,
        min_after_dequeue=1000,
    )
    return batch[0], batch[1]
def main():
    """Build both TFRecord files: the training set first, then the test set."""
    create_train_record()
    create_test_record()


if __name__ == '__main__':
    main()
### 调用示例 ###
# create_train_record(cwd,classes)
# create_test_record(cwd,classes)
# image_batch,label_batch = get_batch_record(filename,32)
# init = tf.initialize_all_variables()
#
# with tf.Session() as sess:
# sess.run(init)
#
# coord = tf.train.Coordinator()
# threads = tf.train.start_queue_runners(sess=sess,coord=coord)
#
# for i in range(1):
# image,label = sess.run([image_batch,label_batch])
# print(image.shape,1)
#
#
# coord.request_stop()
# coord.join(threads)
下一篇将介绍定义神经网络
如有错误望多多指教~~