完全なコード+ githubのを設定しRNN tensorflow-BasicLSTMCell + mnist手書き文字分類データ

まず、原則として一部

RNNの詳細については、ボーエンは参照
リカレントニューラルネットワークは、RNN + tensorflow原則実装詳細
LSTM上を:
私たちは、RNNは任意の時間の前に入力を含む知っているが、これは、勾配や勾配爆発の消失の原因となりますが、良い学習機能することはできません従ってLSTM使用
([0-1]を介して進SIGMOD乗算される)入力(X、H)の制御(1)、
(2)制御「蓄積」状態忘れ度を必要とする前に、
(3)現在の時間を制御しますさらに、出力の必要性。
より柔軟な制御を達成するために。
:ダイアグラムは次のようである
ここに画像を挿入説明
、以下の二つのブログ記事を参照してください非常に詳細に伝えます。
NUM_UNITSパラメータのBasicLSTMCellを説明
明確である説明の上に示されたプロセス変数を取得する方法についてtf.nn.rnn_cell.BasicLSTMCellクラスの次の記事を。
深さの研究ノート2:ニューラルネットワークの入力と出力LSTMの理解

第二に、コードの詳細

図1に示すように、コード分析

二つのファイルの合計は
1であるRNN_LSTM_Classfication.py
1 simple_RNN.py

RNN_LSTM_Classfication.py

ヘッダーファイル、およびデータのロード、データのロードとは、ボーエンで説明したいくつかのエラーtensorflow実行コンボリューションニューラルネットワーク

from tensorflow import keras, Session, transpose, global_variables_initializer
from modular.simple_RNN import simple_RNN
from modular.compute_accuracy import compute_accuracy
from modular.random_choose import corresponding_choose
(train_x_image, train_y), (test_x_image, test_y) = keras.datasets.mnist.load_data(path='/home/xiaoshumiao/.keras/datasets/mnist.npz')
train_y = keras.utils.to_categorical(train_y)
test_y=keras.utils.to_categorical(test_y)

パラメータ設定

epochs = 1000
n_classes = 10
batch_size = 200#number
chunk_size = 28
n_chunk = 28
rnn_size = 128#the lenth of a hidden_neural_layer or the number of hidden_neural
learning_rate = 0.001

自分の定義のクラスのインスタンスRNN

rnn = simple_RNN(chunk_size, n_chunk, rnn_size, batch_size, n_classes, learning_rate)

セッサ()

with Session() as sess:
    sess.run(global_variables_initializer())
    for i in range(epochs):
        train_data = corresponding_choose(train_x_image, batch_size, m=0)
        train_x_betch = train_data.row_2(train_x_image) / 255.
        train_y_betch = train_data.row(train_y)
        sess.run(rnn.train,feed_dict={rnn.X:train_x_betch,rnn.y:train_y_betch})
        if i % 20 ==0:
            test_data = corresponding_choose(test_x_image, 200, m=0)
            test_x_betch = test_data.row_2(test_x_image) / 255.
            test_y_betch = test_data.row(test_y)
            b = sess.run(rnn.result, feed_dict={rnn.X: test_x_betch})
            c = sess.run(compute_accuracy(b, transpose(test_y_betch)))
            print(c)

simple_RNN.py

ヘッダファイルのインポート

from tensorflow import placeholder,float32,transpose, nn, reduce_mean, multiply, log, reduce_sum, train
from add_layer import add_layer
from tensorflow.python.ops.rnn import dynamic_rnn

データ入力

class simple_RNN(object):
    def __init__(self,chunk_size,n_chunk,hidden_chunk_size, batch_size,  n_class, learning_rate):
        self.X = placeholder(float32,[None,n_chunk,chunk_size])#200,28,28
        self.y = placeholder(float32,[None, n_class])

セル定義
初期状態値(二つのリンク上の構造を参照)を
構築RNN図最後の時間ステップを計算し、出力を得るために。

self.LSTM_cell = nn.rnn_cell.BasicLSTMCell(hidden_chunk_size, forget_bias=1.0, state_is_tuple=True)
self.init_state = self.LSTM_cell.zero_state(batch_size, float32)
self.output, self.states = dynamic_rnn(self.LSTM_cell, self.X, initial_state=self.init_state, dtype=float32)

一つの出力コネクタに完全に接続された層を得た後、10次元の出力を得ました。

self.result = add_layer(transpose(self.states[1]), hidden_chunk_size, n_class, activation_function=nn.softmax)

完全に接続されたネットワークトレーニング全出力接続層をトレーニングするための方法

self.loss = reduce_mean(-reduce_sum(multiply(transpose(self.y),log(self.result)),reduction_indices=[0]))
self.train = train.AdamOptimizer(learning_rate).minimize(self.loss)

2、問題解決

(1)、
最初の質問は、長い時間のための問題に悩まされ、私はMoのトラブルに応じて彼らの教訓を学ぶRNNパイソンに応じてMoが疲れていたので、私たちは、最初に手動であるため、その後細胞内に128次元のベクトル層を計算します。しかし、私はボーエンはNUM_UNITSは、神経の数隠された層で定義された元tf.nn.rnn_cell.BasicLSTMCell機能(によるこの層の出力ベクトルの大きさ)、で、発見し、地図dynamic_rnnを計算RNNを構築する際に、私たちが必要でした入力は、生データです。またLSTM入力が一瞬にニューロンと出力生データを隠されている、2つのブログ上記のリンクを参照してください。要約するので、我々は結果を手動で隠された層ニューロンを計算達成しません。
同時に、それはtf.nn.rnn_cell.BasicLSTMCellで見つけることができる
ここに画像を挿入説明
ので、入力データを処理する必要はありません。

(2)、

test_data = corresponding_choose(test_x_image, 200, m=0)

それは200を設定し、または再度変更することができますので、試験サンプルBATCH_SIZEとトレーニングサンプルの冒頭では、満たしていません。

(3)は、
変数を初期化すると正直に言うと、何か他のものは使用しないでください
)(global_variables_initializer

第三に、githubの+実験結果

:次のような結果を達成するために、
ここに画像を挿入説明
全体的に、結果はまだかなりのを。

完全なコード(他の書き込み、自分のライブラリ呼び出し)、および他のプログラムのコメントは、あなたがすべて一緒にgithubの上、場所に注意を払うに必要な、必要が見ることができます。

https://github.com/wangjunhe8127/tensorflow-BasicLSTMCell-mnist

力はあなたと一緒かもしれません!

公開された32元の記事 ウォン称賛7 ビュー2158

おすすめ

転載: blog.csdn.net/def_init_myself/article/details/105372941
おすすめ