The Vanishing Gradient Problem and LSTM Code Snippets

1. What are vanishing and exploding gradients?
Take BP (backpropagation) as an example. By the chain rule, the gradients of the weights and biases of an earlier hidden layer depend on the layers after it. When the activation function is the sigmoid, whose derivative is at most 0.25, the product of these derivatives becomes smaller as the number of layers grows, so the weights and biases of the early layers barely move away from their initial values and the parameters update very slowly. This is the vanishing gradient problem.
By the same reasoning, exploding gradients occur when the chain-rule product becomes particularly large and the parameters update far too fast. Both vanishing and exploding gradients are mainly caused by the network being too deep, which makes weight updates unstable; the essential cause is the multiplicative effect in backpropagation. (A small numeric sketch follows the list below.)
Improvement methods:
(1) Choose a suitable activation function, e.g. replace the sigmoid function with ReLU;
(2) For the vanishing gradient problem in RNNs, use the LSTM structure instead.
For a detailed mathematical explanation of vanishing and exploding gradients, refer to this blog post: https://ziyubiti.github.io/2016/11/06/gradvanish/
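To make the multiplicative effect concrete, here is a minimal NumPy sketch (not part of the original post); the pre-activation value, weight, and layer counts are arbitrary illustrative choices:

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Gradient factor contributed by one layer: sigmoid'(z) * w,
# where sigmoid'(z) = sigmoid(z) * (1 - sigmoid(z)) <= 0.25.
z, w = 0.0, 1.0
factor = sigmoid(z) * (1 - sigmoid(z)) * w   # 0.25 at z = 0

for n_layers in (5, 10, 20, 50):
    # The gradient reaching the earliest layer shrinks geometrically with depth.
    print(n_layers, factor ** n_layers)
# If |factor| were greater than 1, the same product would blow up instead
# (exploding gradient).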
2. LSTM (Long Short-Term Memory) structure
[Figure: LSTM cell structure]
Three gates:
(1) The input gate selects which input signals to let through;
(2) The forget gate selects which signals to forget;
(3) The output gate selects which signals to output.
All three gates receive the same inputs (the current input and the previous hidden state); the gate activations are sigmoid functions, while the activation applied to the cell input and output is generally tanh() (a sketch of one LSTM step follows the figures below).
[Figure: LSTM gate and state update equations]
The meaning of each quantity in the equations is shown in the following figure:
[Figure: definitions of the symbols used in the LSTM equations]
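Since the original figures are not reproduced here, the following is a minimal NumPy sketch of one LSTM time step in the standard formulation; the variable names and the single concatenated weight matrix W are my own illustrative choices, not taken from the figures:

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x_t, h_prev, c_prev, W, b):
    # W maps the concatenated [h_prev, x_t] to the four gate pre-activations.
    z = np.concatenate([h_prev, x_t]) @ W + b
    i, f, o, g = np.split(z, 4)
    i, f, o = sigmoid(i), sigmoid(f), sigmoid(o)   # input, forget, output gates
    g = np.tanh(g)                                 # candidate cell input
    c_t = f * c_prev + i * g                       # new cell state
    h_t = o * np.tanh(c_t)                         # new hidden state / output
    return h_t, c_t

# Example: hidden size 3, input size 2 (random weights, illustrative only)
rng = np.random.default_rng(0)
W = rng.normal(size=(3 + 2, 4 * 3))
b = np.zeros(4 * 3)
h, c = lstm_step(rng.normal(size=2), np.zeros(3), np.zeros(3), W, b)
print(h, c)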
The concrete implementation (TensorFlow 1.x, MNIST digit classification) is as follows:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist=input_data.read_data_sets('MNIST_data/',one_hot=True)

# Input images: 28*28
n_inputs=28   # 28 values per row = number of input neurons
max_time=28   # 28 rows in total = number of time steps
lstm_size=100 # 100 units in the hidden LSTM layer
n_classes=10  # number of output neurons
batch_size=50
n_batch=mnist.train.num_examples//batch_size

# x: flattened 28*28 input images; y: one-hot labels
x=tf.placeholder(tf.float32,[None,784])
y=tf.placeholder(tf.float32,[None,10])

# Weights and biases of the final fully connected output layer
weights=tf.Variable(tf.truncated_normal([lstm_size,n_classes],stddev=0.1))
biases=tf.Variable(tf.constant(0.1,shape=[n_classes]))

def RNN(X,weights,biases):
    # inputs shape: [batch_size, max_time, n_inputs]
    inputs=tf.reshape(X,[-1,max_time,n_inputs])

    # Define the basic LSTM cell
    lstm_cell=tf.contrib.rnn.BasicLSTMCell(lstm_size)
    outputs,final_state=tf.nn.dynamic_rnn(lstm_cell,inputs,dtype=tf.float32)
    # final_state[0] is the cell state.
    # final_state[1] is the hidden state, i.e. the signal produced at the last of the
    # 28 time steps; outputs records the output of every time step (28 results), so
    # the output of the last time step is the same as final_state[1].
    # Return raw logits; the softmax is applied inside the loss function below.
    results=tf.matmul(final_state[1],weights)+biases
    return results
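    # (Equivalent variant, not in the original code:) instead of final_state[1],
    # the last time step could be taken directly from outputs:
    #   last_output = outputs[:, -1, :]
    #   results = tf.matmul(last_output, weights) + biases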

# Build the RNN prediction (logits)
prediction=RNN(x,weights,biases)
# Loss function: softmax cross-entropy between the logits and the labels
cross_entropy=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=prediction))
optimizer=tf.train.AdamOptimizer(1e-4)
train_step=optimizer.minimize(cross_entropy)

# Store the comparison results in a list of booleans
correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
# Test accuracy
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(6):
        for batch in range(n_batch):
            batch_xs,batch_ys=mnist.train.next_batch(batch_size)
            sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys})

        acc=sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})  
        print("Iter "+str(epoch)+" Test Accuracy= "+str(acc))


Run output:

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
Iter 0 Test Accuracy= 0.7476
Iter 1 Test Accuracy= 0.8575
Iter 2 Test Accuracy= 0.8893
Iter 3 Test Accuracy= 0.9181
Iter 4 Test Accuracy= 0.927
Iter 5 Test Accuracy= 0.9359
