# TensorFlow notes: handwritten digit recognition with an LSTM

import tensorflow as tf
import os
import numpy as np
# import tensorflow.examples.tutorials.mnist as
from tensorflow.examples.tutorials.mnist import input_data
# Number of pixels in one image row; one row is fed to the network per time step.
diminput = 28
# Width of the input projection and number of units in the LSTM cell.
dimhidden = 128
# Same value as diminput; this is not the class count (see nclass).
dimoutput = 28
# Number of digit classes, i.e. the size of the logits / one-hot label vectors.
nclass = 10
# Number of training examples requested per mini-batch.
batch_size = 200
# Number of time steps: the projected input is split into this many pieces.
n_steps = 28
# Learning rate passed to the optimizer.
train_rate = 0.001
# train_steps = 10000
# Directory and base file name used when saving the checkpoint.
MODEL_SAVE_PATH = "./model/"
MODEL_NAME = "mnist_model"
# n_batch = mnist.train.num_examples // batch_size
#
# weights ={
#     'hidden' : tf.Variable(tf.random_normal([diminput, dimhidden])),
#     'out' : tf.Variable(tf.random_normal([dimhidden, dimoutput]))
# }
#
# biases = {
#     'hidden' : tf.Variable(tf.random_normal([dimhidden])),
#     'out' : tf.Variable(tf.random_normal([dimoutput]))
# }


def forward(data):
    """Build the LSTM classification graph for a batch of flattened images.

    Each image is treated as a sequence of ``n_steps`` rows of ``diminput``
    pixels. Every row is linearly projected to ``dimhidden`` features, the
    resulting sequence is run through one ``BasicLSTMCell``, and the output
    of the last time step is mapped to ``nclass`` logits.

    Args:
        data: Float tensor that can be reshaped to
            ``[-1, n_steps, diminput]`` (callers in this file feed
            ``[None, 784]`` placeholders).

    Returns:
        dict with two entries:
            "output": the list of per-step LSTM outputs from ``static_rnn``.
            "h": the logits tensor computed from the last step's output.
    """
    # [batch, n_steps, diminput]
    _x = tf.reshape(data, [-1, n_steps, diminput])
    # Time-major layout: [n_steps, batch, diminput]
    _x = tf.transpose(_x, [1, 0, 2])
    # Flatten to [n_steps * batch, diminput] so one matmul projects every row.
    _x = tf.reshape(_x, [-1, diminput])

    # NOTE: keep the creation order of these variables. They are unnamed, so
    # TensorFlow numbers them in creation order and checkpoints rely on it.
    # The projection consumes rows of width `diminput`, so that is the correct
    # first dimension (the original used `dimoutput`, which only worked
    # because both constants are 28).
    w_input = tf.Variable(
        tf.truncated_normal(shape=[diminput, dimhidden], stddev=0.1))
    w_output = tf.Variable(
        tf.truncated_normal(shape=[dimhidden, nclass], stddev=0.1))
    b_input = tf.Variable(tf.random_normal([dimhidden]))
    b_output = tf.Variable(tf.random_normal([nclass]))

    # Project all rows at once, then cut the result back into n_steps
    # consecutive chunks, one per time step, as static_rnn expects a list.
    _H = tf.matmul(_x, w_input) + b_input
    H_split = tf.split(value=_H, num_or_size_splits=n_steps, axis=0)

    rnn_cell = tf.nn.rnn_cell.BasicLSTMCell(dimhidden, forget_bias=1.0)
    output, _ = tf.nn.static_rnn(
        cell=rnn_cell, inputs=H_split, dtype=tf.float32)

    # Only the last time step's output is used for classification.
    h = tf.matmul(output[-1], w_output) + b_output
    return {"output": output, "h": h}


def backward(mnist, n_batch, n_epochs=100):
    """Train the LSTM classifier and write checkpoints to MODEL_SAVE_PATH.

    Builds the loss, accuracy and gradient-descent ops on top of
    ``forward``, then runs ``n_batch * n_epochs`` training steps. Every 100
    steps the cost and accuracy on the current mini-batch are printed, and
    a checkpoint is saved at the end of every epoch so the last one holds
    the final weights.

    Args:
        mnist: Dataset object whose ``train.next_batch(batch_size)`` returns
            an ``(images, one_hot_labels)`` pair, as produced by
            ``input_data.read_data_sets(..., one_hot=True)`` in ``main``.
        n_batch: Number of mini-batches per epoch. Coerced to ``int`` so a
            float produced by true division is still accepted.
        n_epochs: Number of passes of ``n_batch`` steps to run; defaults to
            the previously hard-coded 100.
    """
    n_batch = int(n_batch)

    x = tf.placeholder(dtype=tf.float32, shape=[None, 784])
    y = tf.placeholder(dtype=tf.float32, shape=[None, nclass])
    y_pre = forward(x)["h"]
    cost = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(logits=y_pre, labels=y))
    correct = tf.equal(tf.argmax(y_pre, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct, np.float32))
    opt = tf.train.GradientDescentOptimizer(train_rate).minimize(cost)
    saver = tf.train.Saver()

    # Make sure the checkpoint directory exists before the first save.
    # NOTE(review): some TF 1.x releases raise if the parent directory is
    # missing - confirm against the installed version.
    if not os.path.isdir(MODEL_SAVE_PATH):
        os.makedirs(MODEL_SAVE_PATH)
    checkpoint_path = os.path.join(MODEL_SAVE_PATH, MODEL_NAME)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for i in range(n_batch * n_epochs):
            train_x, train_y = mnist.train.next_batch(batch_size)
            feed = {x: train_x, y: train_y}
            sess.run(opt, feed)
            if i % 100 == 0:
                # Metrics are evaluated after the update, on the same batch.
                _cost, _accuracy = sess.run([cost, accuracy], feed)
                # %-formatting prints the same way on Python 2 and 3.
                print("%d %g %g" % (i, _cost, _accuracy))
            # Save at the end of every epoch (the original saved only after
            # the first one, so later training was never persisted).
            if (i + 1) % n_batch == 0:
                saver.save(sess, checkpoint_path)

def main():
    """Load MNIST with one-hot labels from ./data/ and start training."""
    mnist = input_data.read_data_sets("./data/", one_hot=True)
    # Floor division keeps n_batch an integer on Python 3 as well; true
    # division would yield a float that range() rejects.
    n_batch = mnist.train.num_examples // batch_size
    backward(mnist, n_batch)


if __name__ == '__main__':
    main()
#!/usr/bin/python
# coding:utf-8 
# 19-6-13 5:21 PM
# @File    : app.py
import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import minRNN
# import backward
# Seconds to sleep between successive checkpoint evaluations in test().
TEST_INTERVAL_SECS = 5


def test(mnist):
    """Repeatedly score the latest saved checkpoint on the MNIST test set.

    Rebuilds the network from ``minRNN.forward`` in a fresh graph, then
    loops forever: look up the checkpoint state in
    ``minRNN.MODEL_SAVE_PATH``, restore it, print the test accuracy and
    sleep ``TEST_INTERVAL_SECS`` seconds. Returns as soon as no checkpoint
    can be found.

    Args:
        mnist: Dataset object exposing ``test.images`` and ``test.labels``.
    """
    with tf.Graph().as_default():
        images = tf.placeholder(dtype=tf.float32, shape=[None, 784])
        labels = tf.placeholder(dtype=tf.float32, shape=[None, minRNN.nclass])
        logits = minRNN.forward(images)["h"]

        # Created after forward() so it covers every model variable.
        saver = tf.train.Saver()

        hits = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
        accuracy = tf.reduce_mean(tf.cast(hits, tf.float32))

        while True:
            with tf.Session() as sess:
                ckpt = tf.train.get_checkpoint_state(minRNN.MODEL_SAVE_PATH)
                if not (ckpt and ckpt.model_checkpoint_path):
                    print('None checkpoint file found')
                    return
                saver.restore(sess, ckpt.model_checkpoint_path)
                test_feed = {images: mnist.test.images,
                             labels: mnist.test.labels}
                accuracy_score = sess.run(accuracy, feed_dict=test_feed)
                print("test accuracy = %g" %  accuracy_score)
            time.sleep(TEST_INTERVAL_SECS)

def main():
    """Load MNIST with one-hot labels from ./data/ and run the evaluator."""
    test(input_data.read_data_sets("./data/", one_hot=True))


if __name__ == '__main__':
    main()

# You may also like
#
# Reposted from blog.csdn.net/qq_42105426/article/details/92766534