# 神经网络识别验证码并保存模型
# (Recognize captchas with a neural network and save the model.)

import tensorflow as tf
import os


def get_weight(shape):
    """Create a trainable weight variable of the given shape.

    The initial values are drawn from a normal distribution with
    mean 1.0 and standard deviation 1.0.

    Args:
        shape: Shape of the variable to create.

    Returns:
        A ``tf.Variable`` holding the randomly initialised weights.
    """
    initial_value = tf.random_normal(shape=shape, mean=1.0, stddev=1.0)
    return tf.Variable(initial_value)


def get_bais(shape):
    """Create a trainable bias variable of the given shape.

    The initial values are drawn from a normal distribution with
    mean 1.0 and standard deviation 1.0.

    Args:
        shape: Shape of the variable to create.

    Returns:
        A ``tf.Variable`` holding the randomly initialised biases.
    """
    initial_value = tf.random_normal(shape=shape, mean=1.0, stddev=1.0)
    return tf.Variable(initial_value)


def read_model():
    """Build the TFRecord input pipeline for the captcha data set.

    Reads serialized examples from the captcha TFRecord file, decodes the
    raw ``image`` and ``label`` byte strings as uint8, reshapes them to
    ``[20, 80, 3]`` and ``[4]`` respectively, and groups them into batches
    of 100 samples.

    Returns:
        A tuple ``(image_batch, label_batch)`` of batched tensors. They only
        yield data once queue runners have been started in a session.
    """
    # Step 1: queue of input file names.
    filename_queue = tf.train.string_input_producer(["./CaptchaRecognize/tfrecords/captcha.tfrecords"])

    # Step 2: reader that pulls one serialized example per read.
    record_reader = tf.TFRecordReader()
    _, serialized_example = record_reader.read(filename_queue)

    # The records are tf.Example protos, so they have to be parsed first.
    feature_spec = {
        "image": tf.FixedLenFeature([], tf.string),
        "label": tf.FixedLenFeature([], tf.string),
    }
    parsed = tf.parse_single_example(serialized_example, features=feature_spec)

    # Both fields are stored as raw byte strings: decode the image
    # (feature values) first, then the label (target values).
    raw_image = tf.decode_raw(parsed["image"], tf.uint8)
    raw_label = tf.decode_raw(parsed["label"], tf.uint8)

    # Give the flat byte vectors their static shapes before batching.
    single_image = tf.reshape(raw_image, [20, 80, 3])
    single_label = tf.reshape(raw_label, [4])

    batched_images, batched_labels = tf.train.batch(
        [single_image, single_label],
        batch_size=100,
        num_threads=1,
        capacity=100,
    )
    return batched_images, batched_labels


def start_run():
    """Train a single-layer softmax captcha classifier and save a checkpoint.

    Builds the graph (one fully connected layer mapping a flattened
    20*80*3 image to 4 positions x 26 classes), restores an existing
    checkpoint if one is present, runs 10000 gradient-descent steps while
    printing the per-character accuracy, and finally saves the model to
    ``./data/yanzhengma/data``.

    Returns:
        None.
    """
    model_path = "./data/yanzhengma/data"
    model_dir = os.path.dirname(model_path)
    # The Saver writes its index file, named "checkpoint", next to the model.
    checkpoint_index = os.path.join(model_dir, "checkpoint")

    image_batch, label_batch = read_model()

    # Model: one fully connected layer producing 4 * 26 logits per image.
    with tf.variable_scope("model"):
        weights = get_weight([20 * 80 * 3, 26 * 4])
        bais = get_bais([26 * 4])
        image_reshape = tf.reshape(image_batch, [-1, 20 * 80 * 3])
        image_reshape = tf.cast(image_reshape, tf.float32)
        y_predict = tf.matmul(image_reshape, weights) + bais
        y_predict = tf.reshape(y_predict, [-1, 4, 26])

    # One-hot encode the true labels.
    with tf.variable_scope("one-hot"):
        # The labels are [batch, 4]; the depth axis must be appended as the
        # LAST axis (axis=2) to obtain [batch, 4, 26]. With axis=1 the result
        # is [batch, 26, 4], and reshaping that to [-1, 4, 26] scrambles the
        # targets instead of transposing them.
        y_true = tf.one_hot(label_batch, depth=26, axis=2, on_value=1.0)
        y_true = tf.reshape(y_true, shape=[-1, 4, 26])

    # Loss: softmax cross entropy over the 26 classes of every position.
    with tf.variable_scope("loss"):
        loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_predict))

    # Plain gradient descent.
    with tf.variable_scope("optimizer"):
        train_op = tf.train.GradientDescentOptimizer(0.001).minimize(loss)

    # Accuracy: fraction of individual characters predicted correctly.
    with tf.variable_scope("acc"):
        # equal_list has one boolean per character, shape [batch, 4].
        equal_list = tf.equal(tf.argmax(y_true, 2), tf.argmax(y_predict, 2))
        accuracy = tf.reduce_mean(tf.cast(equal_list, tf.float32))

    saver = tf.train.Saver()

    with tf.Session() as sess:
        print("进入会话")
        sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord=coord)
        try:
            if os.path.exists(checkpoint_index):
                saver.restore(sess, model_path)

            print("初始化变量")
            for i in range(10000):
                # Fetch the accuracy in the same run() call as the training
                # step: a separate accuracy.eval() would dequeue another
                # batch, so the printed value would describe data the step
                # never saw and every iteration would consume two batches.
                _, acc_value = sess.run([train_op, accuracy])
                print("第%d次,准确率%f" % (i, acc_value))

            # Saver.save() fails if the target directory does not exist.
            if model_dir and not os.path.isdir(model_dir):
                os.makedirs(model_dir)
            saver.save(sess, model_path)
        finally:
            # Always stop and join the input threads, even when training or
            # saving raises, so the queue runners do not outlive the session.
            coord.request_stop()
            coord.join(threads)


# Script entry point: start training only when this file is executed directly.
if __name__ == "__main__":
    start_run()
# 发布了75 篇原创文章 · 获赞 9 · 访问量 1万+
#
# 猜你喜欢
#
# 转载自blog.csdn.net/sinat_40387150/article/details/90142864