Predicting the Next Letter with SimpleRNN

Let's walk through a next-letter prediction example built on SimpleRNN. The script reads a text file, slices it into fixed-length character windows, one-hot encodes each window, trains a single SimpleRNN layer with a softmax output over the character vocabulary, and then generates text one character at a time.
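
To make the slicing concrete before reading the full script: each window of seq_len characters is an input sample, and the character that follows it is the label. A minimal sketch (seq_len=3 and step=1 are illustrative values, not the script's defaults):

text = 'hello world'
seq_len, step = 3, 1
# each window of seq_len chars is an input; the char after it is the label
samples = [(text[i:i + seq_len], text[i + seq_len])
           for i in range(0, len(text) - seq_len, step)]
print(samples[:3])  # [('hel', 'l'), ('ell', 'o'), ('llo', ' ')]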

from __future__ import print_function

from keras.layers import Dense, Activation
from keras.layers.recurrent import SimpleRNN
from keras.models import Sequential
import numpy as np

# use SimpleRNN to predict the next letter of a character sequence
class RNNSimple:
    def __init__(self, hidden_size=128, batch_size=128,
                 num_iter=24, num_epoch=1, num_pred=100, seq_len=10, step=1):
        self.hidden_size = hidden_size
        self.batch_size = batch_size
        self.num_iter = num_iter
        self.num_epoch = num_epoch
        self.num_pred = num_pred
        self.seq_len = seq_len
        self.step = step

    def read_text(self, file_path):
        # read the corpus as bytes, lowercase it, keep ASCII only, drop blank lines
        lines = []
        with open(file_path, 'rb') as f:
            for line in f:
                line = line.strip().lower()
                line = line.decode('ascii', 'ignore')
                if len(line) == 0:
                    continue
                lines.append(line)
        text = ' '.join(lines)
        return text

    def vectorize(self, text):
        # build the character vocabulary and index maps (sorted so indices are reproducible across runs)
        chars = sorted(set(text))
        self.chars_count = len(chars)
        self.char2index = dict((c, i) for i, c in enumerate(chars))
        self.index2char = dict((i, c) for i, c in enumerate(chars))
        print(self.char2index)
        print(self.index2char)
        # slide a seq_len window over the text; the character after each window is its label
        self.input_chars = []
        self.label_chars = []
        for i in range(0, len(text) - self.seq_len, self.step):
            self.input_chars.append(text[i: i+self.seq_len])
            self.label_chars.append(text[i+self.seq_len])
        # peek at a few samples (printing the full lists floods the console on a real corpus)
        print(self.input_chars[:5])
        print(self.label_chars[:5])
        # one-hot encode inputs and labels
        X = np.zeros((len(self.input_chars), self.seq_len, self.chars_count), dtype=bool)
        Y = np.zeros((len(self.input_chars), self.chars_count), dtype=bool)
        for i, input_char in enumerate(self.input_chars):
            for j, c in enumerate(input_char):
                X[i, j, self.char2index[c]] = 1
            Y[i, self.char2index[self.label_chars[i]]] = 1

        print(X.shape)
        print(Y.shape)
        return X, Y

    def train(self, X, Y):
        # build model
        model = Sequential()
        model.add(SimpleRNN(self.hidden_size, return_sequences=False,
                            input_shape=(self.seq_len, self.chars_count), unroll=True))
        model.add(Dense(self.chars_count))
        model.add(Activation('softmax'))
        model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
        # train in num_iter rounds of num_epoch epochs, logging the round number
        for iteration in range(self.num_iter):
            print('Iteration: %d' % iteration)
            model.fit(X, Y, batch_size=self.batch_size, epochs=self.num_epoch)
        return model

    def predict(self, model, test_chars):
        result = test_chars
        epoch_chars = test_chars
        for i in range(self.num_pred):
            vect_test = np.zeros((1, self.seq_len, self.chars_count))
            # one-hot encode the current seed window
            for j, ch in enumerate(epoch_chars):
                vect_test[0, j, self.char2index[ch]] = 1
            pred = model.predict(vect_test, verbose=0)[0]
            pred_char = self.index2char[np.argmax(pred)]
            result += pred_char
            # slide the window: drop the first char, append the prediction
            epoch_chars = epoch_chars[1:] + pred_char
        return result

    def process(self):
        # 1. read text from file
        text = self.read_text('./test.txt')
        print(len(text))
        # 2. vectorize text
        X, Y = self.vectorize(text)
        # 3. train based on X, Y
        model = self.train(X, Y)
        # 4. try predict
        test_idx = np.random.randint(len(self.input_chars))
        test_chars = self.input_chars[test_idx]
        print('test seed is: %s' % test_chars)
        result = self.predict(model, test_chars)
        print('result is: %s' % result)

if __name__ == '__main__':
    rnn_simple = RNNSimple()
    rnn_simple.process()
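
process() expects a plain-text corpus at ./test.txt. Any ASCII text file will do; a minimal sketch to fetch one, assuming network access (the Project Gutenberg URL and the choice of corpus are my own, not from the original post):

# fetch a small public-domain text to serve as ./test.txt
# (URL is an assumption; any plain ASCII text file works)
import urllib.request

urllib.request.urlretrieve('https://www.gutenberg.org/files/11/11-0.txt', 'test.txt')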

Output

Iteration: 0
Epoch 1/1
1739/1739 [==============================] - 0s 180us/step - loss: 3.1235
Iteration: 1
Epoch 1/1
1739/1739 [==============================] - 0s 42us/step - loss: 2.8657
Iteration: 2
Epoch 1/1
1739/1739 [==============================] - 0s 41us/step - loss: 2.7682
Iteration: 3
Epoch 1/1
1739/1739 [==============================] - 0s 42us/step - loss: 2.6973
Iteration: 4
Epoch 1/1
1739/1739 [==============================] - 0s 41us/step - loss: 2.6078
Iteration: 5
Epoch 1/1
1739/1739 [==============================] - 0s 42us/step - loss: 2.5333
Iteration: 6
Epoch 1/1
1739/1739 [==============================] - 0s 41us/step - loss: 2.4714
Iteration: 7
Epoch 1/1
1739/1739 [==============================] - 0s 41us/step - loss: 2.4033
Iteration: 8
Epoch 1/1
1739/1739 [==============================] - 0s 41us/step - loss: 2.3416
Iteration: 9
Epoch 1/1
1739/1739 [==============================] - 0s 42us/step - loss: 2.2842
Iteration: 10
Epoch 1/1
1739/1739 [==============================] - 0s 41us/step - loss: 2.2356
Iteration: 11
Epoch 1/1
1739/1739 [==============================] - 0s 40us/step - loss: 2.1725
Iteration: 12
Epoch 1/1
1739/1739 [==============================] - 0s 41us/step - loss: 2.1179
Iteration: 13
Epoch 1/1
1739/1739 [==============================] - 0s 41us/step - loss: 2.0791
Iteration: 14
Epoch 1/1
1739/1739 [==============================] - 0s 41us/step - loss: 2.0257
Iteration: 15
Epoch 1/1
1739/1739 [==============================] - 0s 41us/step - loss: 1.9759
Iteration: 16
Epoch 1/1
1739/1739 [==============================] - 0s 42us/step - loss: 1.9279
Iteration: 17
Epoch 1/1
1739/1739 [==============================] - 0s 41us/step - loss: 1.8844
Iteration: 18
Epoch 1/1
1739/1739 [==============================] - 0s 42us/step - loss: 1.8449
Iteration: 19
Epoch 1/1
1739/1739 [==============================] - 0s 41us/step - loss: 1.7889
Iteration: 20
Epoch 1/1
1739/1739 [==============================] - 0s 41us/step - loss: 1.7589
Iteration: 21
Epoch 1/1
1739/1739 [==============================] - 0s 41us/step - loss: 1.7087
Iteration: 22
Epoch 1/1
1739/1739 [==============================] - 0s 42us/step - loss: 1.6699
Iteration: 23
Epoch 1/1
1739/1739 [==============================] - 0s 41us/step - loss: 1.6291
test seed is: rrive at t
result is: rrive at the input nodes area ion is the sem of the seruen eo an te de the  nm or es ae  eamhen niin  t ae ter
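
The seed is reproduced, but greedy argmax decoding on a lightly trained model soon drifts into repetitive near-gibberish. A common tweak, not used in the original post, is to sample from the softmax output with a temperature; a minimal sketch (the function name and default temperature are my own):

import numpy as np

def sample_with_temperature(probs, temperature=0.8):
    # rescale the distribution: temperature < 1 sharpens it toward argmax,
    # temperature > 1 flattens it toward uniform; then draw an index
    logits = np.log(np.asarray(probs, dtype=np.float64) + 1e-8) / temperature
    exp = np.exp(logits - np.max(logits))
    p = exp / exp.sum()
    return np.random.choice(len(p), p=p)

In predict(), replacing np.argmax(pred) with sample_with_temperature(pred) tends to make the generated text less repetitive, at the cost of occasional spelling noise.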

Reprinted from blog.csdn.net/lwc5411117/article/details/83448751