import tensorflow as tf
import os
def get_weight(shape):
    """Create a trainable weight variable of the given shape.

    The initial values are sampled from a normal distribution with
    mean 1.0 and standard deviation 1.0.

    NOTE(review): a mean of 1.0 is unusual for weight initialisation
    (zero-mean is the common choice); kept as-is to preserve behaviour.
    """
    initial_value = tf.random_normal(shape=shape, mean=1.0, stddev=1.0)
    return tf.Variable(initial_value)
def get_bais(shape):
    """Create a trainable bias variable of the given shape.

    The function name keeps its original ("bais") spelling because
    callers refer to it by that name. Initial values are sampled from a
    normal distribution with mean 1.0 and standard deviation 1.0.
    """
    return tf.Variable(tf.random_normal(shape=shape, stddev=1.0, mean=1.0))
def read_model(tfrecords_path="./CaptchaRecognize/tfrecords/captcha.tfrecords",
               batch_size=100, num_threads=1, capacity=100):
    """Build a queue-based input pipeline over a captcha TFRecords file.

    Each serialized example is parsed for two scalar bytes features,
    ``"image"`` and ``"label"``:

    * the image bytes are decoded as uint8 and reshaped to [20, 80, 3];
    * the label bytes are decoded as uint8 and reshaped to [4].
      NOTE(review): presumably one class index (0-25) per captcha
      character, matching the 26-way one-hot in ``start_run``; confirm
      against the script that wrote the TFRecords.

    All defaults reproduce the original hard-coded values.

    Args:
        tfrecords_path: Path of the TFRecords file. A relative path is
            resolved against the current working directory.
        batch_size: Number of examples per returned batch.
        num_threads: Number of threads ``tf.train.batch`` uses to enqueue.
        capacity: Maximum number of examples held in the batching queue.

    Returns:
        A tuple ``(image_batch, label_batch)`` of uint8 tensors with shapes
        [batch_size, 20, 80, 3] and [batch_size, 4].

    The returned tensors are fed by queue runners, so the caller has to
    start them (``tf.train.start_queue_runners``) before evaluating them.
    """
    # Filename queue -> TFRecord reader; each read yields one serialized example.
    file_queue = tf.train.string_input_producer([tfrecords_path])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(file_queue)

    # Examples are stored in tf.train.Example format and must be parsed.
    features = tf.parse_single_example(serialized_example, features={
        "image": tf.FixedLenFeature([], tf.string),
        "label": tf.FixedLenFeature([], tf.string),
    })

    # Both features hold raw byte strings; decode them and give them static shapes.
    image = tf.reshape(tf.decode_raw(features["image"], tf.uint8), [20, 80, 3])
    label = tf.reshape(tf.decode_raw(features["label"], tf.uint8), [4])

    image_batch, label_batch = tf.train.batch(
        [image, label],
        batch_size=batch_size,
        num_threads=num_threads,
        capacity=capacity,
    )
    return image_batch, label_batch
def start_run():
    """Train the single-layer captcha classifier and checkpoint it.

    Builds a linear model that maps each flattened 20x80x3 image from
    ``read_model`` to 4 x 26 logits, i.e. one 26-way softmax per captcha
    character. It trains the model with plain gradient descent for a fixed
    number of steps and saves the variables under ``./data/yanzhengma/data``.
    If a ``checkpoint`` file already exists in ``./data/yanzhengma`` the
    saved variables are restored first, so training resumes from them.

    The printed accuracy is per character: the fraction of the
    batch's (example, character) pairs whose arg-max prediction matches
    the label.
    """
    model_path = "./data/yanzhengma/data"
    checkpoint_dir = os.path.dirname(model_path)
    num_steps = 10000
    # Checkpoint periodically rather than on every step: each save rewrites
    # the checkpoint files on disk.
    save_every = 100

    image_batch, label_batch = read_model()

    # Model: a single fully connected layer, 20*80*3 inputs -> 4*26 logits.
    with tf.variable_scope("model"):
        weights = get_weight([20 * 80 * 3, 26 * 4])
        bias = get_bais([26 * 4])
        image_flat = tf.cast(tf.reshape(image_batch, [-1, 20 * 80 * 3]), tf.float32)
        y_predict = tf.matmul(image_flat, weights) + bias
        # [batch, 4, 26]: one row of 26 logits per captcha character.
        y_predict = tf.reshape(y_predict, [-1, 4, 26])

    # One-hot encoded ground truth.
    with tf.variable_scope("one-hot"):
        # label_batch is [batch, 4]. axis=-1 appends the depth axis, giving
        # [batch, 4, 26] directly, which lines up with y_predict. (axis=1 would
        # produce [batch, 26, 4], and reshaping that to [batch, 4, 26]
        # scrambles the labels instead of transposing them.)
        y_true = tf.one_hot(label_batch, depth=26, axis=-1, on_value=1.0)

    # Loss: softmax cross-entropy over the last (26-class) axis, averaged.
    with tf.variable_scope("loss"):
        loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_predict))

    # Plain gradient descent.
    with tf.variable_scope("optimizer"):
        train_op = tf.train.GradientDescentOptimizer(0.001).minimize(loss)

    # Accuracy: compare arg-max class per character, average over all of them.
    with tf.variable_scope("acc"):
        equal_list = tf.equal(tf.argmax(y_true, 2), tf.argmax(y_predict, 2))
        accuracy = tf.reduce_mean(tf.cast(equal_list, tf.float32))

    saver = tf.train.Saver()
    with tf.Session() as sess:
        print("进入会话")
        sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord=coord)
        try:
            if os.path.exists(os.path.join(checkpoint_dir, "checkpoint")):
                saver.restore(sess, model_path)
            print("初始化变量")
            # Saver.save cannot create the parent directory itself.
            os.makedirs(checkpoint_dir, exist_ok=True)
            for i in range(num_steps):
                if coord.should_stop():
                    break
                # Fetch the update and the accuracy in ONE run so both use the
                # same dequeued batch; a separate accuracy.eval() would pull
                # (and waste) a different batch from the input queue.
                _, acc = sess.run([train_op, accuracy])
                print("第%d次,准确率%f" % (i, acc))
                if (i + 1) % save_every == 0 or i == num_steps - 1:
                    saver.save(sess, model_path)
        finally:
            # Always stop and join the queue-runner threads, even on error.
            coord.request_stop()
            coord.join(threads)
if __name__ == "__main__":
    # Train only when run as a script, not when this module is imported.
    start_run()
# 神经网络识别验证码并保存模型
# (Neural-network captcha recognition and model saving.)
# Source: blog.csdn.net/sinat_40387150/article/details/90142864