本篇博客我们将学习使用TensorFlow搭建一个循环神经网络(RNN)模型,并用它来训练MNIST数据集。RNN在自然语言处理领域的以下几个方向已经取得了非常大的成功:
- 机器翻译
- 语音识别
- 图像描述生成(将RNN与CNN相互结合)
- 语言模型与文本生成,即利用生成模型预测下一个单词的可能性。
接下来我们讲解如何使用RNN完成MNIST数据集的分类问题:
(1)加载数据并设置超参数(学习率、训练次数、每轮训练数据大小)
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
mnist=input_data.read_data_sets('MNIST_data',one_hot=True)
lr=0.001
training_iters=100000
batch_size=128
为了使用RNN来分类图像,我们将每张图像的行看做一个像素序列,MNIST数据集中的图像是28*28,所以每张图像存在28行28个元素的序列。因此在RNN模型中每一步输入的序列长度为28,输入的步数为28步。
(2)定义RNN模型的参数
#神经网络参数
n_inputs=28
n_steps=28
n_hidden_units=128
n_classes=10
#输入数据的占位符
x=tf.placeholder(tf.float32,[None,n_steps,n_inputs])
y=tf.placeholder(tf.float32,[None,n_classes])
#定义权重
weights={'in':tf.Variable(tf.random_normal([n_inputs,n_hidden_units])),
'out':tf.Variable(tf.random_normal([n_hidden_units,n_classes]))}
biases={'in':tf.Variable(tf.constant(0.1,shape=[n_hidden_units,])),
'out':tf.Variable(tf.constant(0.1,shape=[n_classes,]))}
(3)定义RNN模型
def RNN(X,weights,biases):
#将输入Xreshape成(128batch×28steps,28inputs)
X=tf.reshape(X,[-1,n_inputs])
#x_in=(128batch*28steps,128hidden)
X_in=tf.matmul(X,weights['in'])+biases['in']
#x_in=(128batch,28steps,128hidden)
X_in=tf.reshape(X_in,[-1,n_steps,n_hidden_units])
#神经网络单元采用LSTM:basic LSTM Cell
lstm_cell=tf.contrib.rnn.BasicLSTMCell(n_hidden_units,forget_bias=1.0,
state_is_tuple=True)
#初始化为零值
init_state=lstm_cell.zero_state(batch_size,dtype=tf.float32)
#dynamic_rnn接收张量(steps,batch,inputs)作为x_in
outputs,final_state=tf.nn.dynamic_rnn(lstm_cell,X_in,initial_state=init_state,
time_major=False)
results=tf.matmul(final_state[1],weights['out'])+biases['out']
return (results)
(4) 定义损失函数和优化器,优化器采用AdamOptimizer:
pred=RNN(x,weights,biases)
cost=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred,labels=y))
train_op=tf.train.AdamOptimizer(lr).minimize(cost)
(5)定义模型预测结果和准确率计算方法:
correct_pre=tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
accuracy=tf.reduce_mean(tf.cast(correct_pre,tf.float32))
(6)训练数据和评估模型,在一个会话中启动图,开始训练每20次输出1次准确率大小:
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
step=0
while step*batch_size<training_iters:
batch_xs,batch_ys=mnist.train.next_batch(batch_size)
batch_xs=batch_xs.reshape([batch_size,n_steps,n_inputs])
sess.run([train_op],feed_dict={x:batch_xs,
y:batch_ys,})
if step%20==0:
print(sess.run(accuracy,feed_dict={x:batch_xs,
y:batch_ys,}))
step+=1
程序在Python3下运行(如果程序报错删除中文注释),运行结果如下: