写一个练手的验证码识别项目
数据集和完整代码我会传到我的下载资源
这篇文章是项目第一步,创建tfrecords文件
# 代码逻辑
1.读取图片文件
2.读取csv文件
3.处理一下读取好的csv文件到数字张量
4.写入tfrecords文件
1.读取图片文件
1.创建文件队列
2.构造阅读器取读取文件内容
3.选择相应的文件解码器取decode
4.要根据验证码的尺寸取setshape,因为读取过来的是一起读过来的
5.批处理数据
def get_image():
"""
获取验证码图片内容
:return: 批处理
"""
file_name = os.listdir("./yz/train10000")
# 构造路径加文件名
file_list = [os.path.join(FLAGS.captcha_dir, file) for file in file_name]
# 构造文件队列
file_queue = tf.train.string_input_producer(file_list, shuffle=False)
# 构造阅读器
reader = tf.WholeFileReader()
# 读取文件内容
key, value = reader.read(file_queue)
# 解码文件数据
image = tf.image.decode_jpeg(value)
image.set_shape([180, 60, 3])
# 批处理数据
image_batch = tf.train.batch([image], batch_size=10318, num_threads=1, capacity=10318)
return image_batch
2.得到csv文件
这里要先说明一下数据集的样式,他是jegp格式的,每个图片的文件名就是验证码的内容。所以我先读取了文件名建立了一个csv文件,在来读csv
建立csv文件
import os
import tensorflow as tf
# def get_name():
# """
# 得到训练集真实数据的函数
# :return: a
# """
print(os.listdir("./genpics/train/"))
a=list()
file_list=os.listdir("./yz/train10000")
with open("./data2.csv","w") as f:
for i in range(len(file_list)):
f.write(file_list[i][0:4]+"\n")# a.append()
# print(a)
读取csv
1.创建文件队列
2.创建文件阅读器
3.读文件
4.解码decode,注意格式
5.批处理
def get_label():
"""
获取验证码文件的标签数据,其实也就是获取真实值
:return: 真实值
"""
file_queue = tf.train.string_input_producer(["./data2.csv"], shuffle=False)
reader = tf.TextLineReader()
key, value = reader.read(file_queue)
records = [[1], ["None"]]
key, label = tf.decode_csv(value, record_defaults=records)
label_batch = tf.train.batch([label], batch_size=10318, num_threads=1, capacity=10318)
return label_batch
3.处理字符串标签张量
刚才读取的csv文件其实就是真实值,但是还是字符串,怎么能去比较呢,所以我们要把他处理成数字的类型,于是乎我们就要建立字典去把他们一一对应起来。
1.建立字符索引
2.键值反转
3.构建标签列表
4.对标签列表进行处理
def deal_label(label_str):
"""
处理字符串标签张量
"""
# 构建字符索引
num_letter = dict(enumerate(list(FLAGS.letter)))
# 键值反转
letter_num = dict(zip(num_letter.values(), num_letter.keys()))
# 构建标签列表
array = []
# 对标签数据进行处理
for string in label_str:
letter_list = []
# 修改编码方式为”utf-8“,并循环找到每张验证码字符对应的数字标记
for letter in string.decode('utf-8'):
letter_list.append(letter_num[letter])
array.append(letter_list)
# 将array转换成tensor类型
label = tf.constant(array)
return label
4.写tfrecoreds文件
把处理好的标签和image存到tfrecords文件
1.把标签转换成tf.uint8类型
2.建立tfrecords存储器
3.建立一个协议块,规定格式,注意这里的写法
4.写入,关闭文件。当然这里用with更好
def write_to_tfrecords(image_batch, label_batch):
"""
将图片内容和标签写入tfrecords文件
"""
# 转换类型
label_batch = tf.cast(label_batch, tf.uint8)
# 建立tfrecords存储器
writer = tf.python_io.TFRecordWriter(FLAGS.tfrecords_dir)
# 循环将图片上每一个example协议快,序列化后写入
for i in range(5000):
image_string = image_batch[i].eval().tostring()
label_string = label_batch[i].eval().tostring
example = tf.train.Example(feature=tf.train.Feature(feature={
"image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_string])),
"label": tf.train.Feature(bytes_list=tf.train.BytesList(value=[label_string]))
}))
writer.write(example.SerializeToString())
writer.close()
return None
main函数
def main():
# 获取当前的图片文件
image_bacth = get_image()
# 获取验证码文件中标签数据
label_batch = get_label()
print(image_bacth, label_batch)
with tf.Session() as sess:
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
label_str = sess.run(label_batch)
# 处理字符串标签到数字张量
babel = deal_label(label_str)
# 写入到tfrecords文件中
write_to_tfrecords(image_bacth, label_batch)
coord.request_stop()
coord.join(threads)
author:[email protected] 欢迎交流