TensorFlow基础教程:搭建循环神经网络RNN

使用TensorFlow搭建循环神经网络

  • TensorFlow版本1.4.0
  • Python版本>3.5.0

循环神经网络RNN的原理可以参考这篇文章

本教程搭建的网络结构包含LSTM和一个全连接层

网络结构图如下:

输出—>LSTM—>全连接层—>输出

1.载入MNIST数据集

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

2.定义参数
RNN输入是一个时间序列,MNIST数据集中图片大小为28px*28px,可以将每一行的像素看成一个序列长度,那么时间步长就是28.

batch_size = 64
n_input = 784    # 图像大小
time_steps = 28  # 时间步长
input_size = 28  # 序列长度
num_classes = 10
rnn_size = 128   # rnn隐藏层大小
lr = 0.01

3.定义网络输出

x = tf.placeholder(tf.float32, shape=[None, n_input])
y = tf.placeholder(tf.float32, shape=[None, num_classes])

4.定义网络主结构

def rnn_model(x):
    # 将输入x变为[batch_size, time_steps, input_size]
    x = tf.reshape(x, shape=[-1, time_steps, input_size])
    # 构建rnn
    rnn_cell = tf.nn.rnn_cell.BasicLSTMCell(rnn_size)
    # 将输入送入rnn,得到输出与中间状态,输出shape为[batch_size, time_steps, rnn_size]
    outputs, states = tf.nn.dynamic_rnn(rnn_cell, x, dtype=tf.float32)
    # 获取最后一个时刻的输出,输出shape为[batch_size, rnn_size]
    output = tf.transpose(outputs, [1,0,2])[-1]
    # 全连接层,最终输出大小为[batch_size, num_classes]
    fc_w = tf.Variable(tf.random_normal([rnn_size, num_classes]))
    fc_b = tf.Variable(tf.random_normal([num_classes]))
    return tf.matmul(output, fc_w) + fc_b

5.构建网络

logits = rnn_model(x)
prediction = tf.nn.softmax(logits)

6.定义损失函数与优化器

loss_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))
optimizer = tf.train.GradientDescentOptimizer(learning_rate=lr)

7.定义评价指标

correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

8.训练网络

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)

    total_batch = mnist.train.num_examples // batch_size
    for epoch in range(train_epochs):
        for batch in range(total_batch):

            batch_x, batch_y = mnist.train.next_batch(batch_size)
            sess.run(train_op, feed_dict={x:batch_x, y:batch_y})

            if batch % 200 == 0:
                loss, acc = sess.run([loss_op, accuracy], feed_dict={x:batch_x, y:batch_y})
                print("epoch {}, batch {}, loss {:.4f}, accuracy {:.3f}".format(epoch, batch, loss, acc))

    print("optimization finished")

    test_acc = sess.run(accuracy, feed_dict={x:mnist.test.images, y:mnist.test.labels})
    print("test acc", test_acc)

github源代码
https://github.com/gamersover/tensorflow_basic_tutorial/blob/master/basic_model/rnn_mnist.py

猜你喜欢

转载自blog.csdn.net/cetrol_chen/article/details/80357318
今日推荐