下面看一下用 SimpleRNN 实现的“预测下一个字母”示例。
from __future__ import print_function
from keras.layers import Dense, Activation
from keras.layers.recurrent import SimpleRNN
from keras.models import Sequential
from keras.utils.vis_utils import plot_model
import numpy as np
# using simpleRNN to generate next letter
class RNNSimple:
    """Character-level next-letter prediction demo built on a Keras SimpleRNN.

    The corpus is cut into sliding windows of ``seq_len`` characters; the
    model learns to predict the single character that follows each window,
    then generates text greedily from a random seed window.
    """

    def __init__(self, hidden_size=128, batch_size=128,
                 num_iter=24, num_epoch=1, num_pred=100, seq_len=10, step=1):
        # hidden_size: number of SimpleRNN units
        # batch_size:  mini-batch size passed to model.fit
        # num_iter:    outer training rounds (fit is called once per round)
        # num_epoch:   epochs per fit call
        # num_pred:    number of characters to generate in predict()
        # seq_len:     length of each input character window
        # step:        stride between successive windows
        self.hidden_size = hidden_size
        self.batch_size = batch_size
        self.num_iter = num_iter
        self.num_epoch = num_epoch
        self.num_pred = num_pred
        self.seq_len = seq_len
        self.step = step

    def read_text(self, file_path):
        """Read the corpus, lower-case it, drop blank lines and non-ASCII
        bytes, and join everything into one space-separated string."""
        lines = []
        with open(file_path, 'rb') as f:
            for line in f:
                line = line.strip().lower()
                # drop non-ASCII bytes so the character vocabulary stays small
                line = line.decode('ascii', 'ignore')
                if len(line) == 0:
                    continue
                lines.append(line)
        text = ' '.join(lines)
        return text

    def vectorize(self, text):
        """One-hot encode sliding windows of ``text``.

        Returns (X, Y): X has shape (num_windows, seq_len, chars_count),
        Y has shape (num_windows, chars_count). Also populates
        self.chars_count, self.char2index, self.index2char,
        self.input_chars and self.label_chars.
        """
        # BUGFIX: sort the vocabulary so the char<->index mappings are
        # deterministic across runs (plain set iteration order is not).
        chars = sorted(set(text))
        self.chars_count = len(chars)
        self.char2index = dict((c, i) for i, c in enumerate(chars))
        self.index2char = dict((i, c) for i, c in enumerate(chars))
        print(self.char2index)
        print(self.index2char)
        # generate input windows and their next-character labels
        self.input_chars = []
        self.label_chars = []
        for i in range(0, len(text) - self.seq_len, self.step):
            self.input_chars.append(text[i: i + self.seq_len])
            self.label_chars.append(text[i + self.seq_len])
        # print only the count; dumping the full lists floods the console
        print('num windows: %d' % len(self.input_chars))
        # BUGFIX: np.bool was deprecated in NumPy 1.20 and removed in 1.24;
        # the builtin bool is the supported spelling for a boolean dtype.
        X = np.zeros((len(self.input_chars), self.seq_len, self.chars_count), dtype=bool)
        Y = np.zeros((len(self.input_chars), self.chars_count), dtype=bool)
        for i, input_char in enumerate(self.input_chars):
            for j, c in enumerate(input_char):
                X[i, j, self.char2index[c]] = 1
            Y[i, self.char2index[self.label_chars[i]]] = 1
        print(X.shape)
        print(Y.shape)
        return X, Y

    def train(self, X, Y):
        """Build and fit the SimpleRNN -> Dense(softmax) classifier.

        Returns the trained Keras model.
        """
        model = Sequential()
        model.add(SimpleRNN(self.hidden_size, return_sequences=False,
                            input_shape=(self.seq_len, self.chars_count), unroll=True))
        model.add(Dense(self.chars_count))
        model.add(Activation('softmax'))
        model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
        # num_iter outer rounds so progress is printed between fit calls
        for iteration in range(self.num_iter):
            print('Iteration: %d' % iteration)
            model.fit(X, Y, batch_size=self.batch_size, epochs=self.num_epoch)
        return model

    def predict(self, model, test_chars):
        """Greedily generate ``num_pred`` characters seeded with test_chars.

        Returns the seed concatenated with the generated characters.
        """
        result = test_chars
        epoch_chars = test_chars
        # BUGFIX: the original reused ``i`` for both loops, shadowing the
        # outer generation counter with the inner window position.
        for _ in range(self.num_pred):
            vect_test = np.zeros((1, self.seq_len, self.chars_count))
            # one-hot encode the current window
            for pos, ch in enumerate(epoch_chars):
                vect_test[0, pos, self.char2index[ch]] = 1
            pred = model.predict(vect_test, verbose=0)[0]
            pred_char = self.index2char[np.argmax(pred)]
            result += pred_char
            # slide the window: drop the oldest char, append the prediction
            epoch_chars = epoch_chars[1:] + pred_char
        return result

    def process(self, file_path='./test.txt'):
        """End-to-end demo: read corpus, vectorize, train, sample.

        ``file_path`` was hard-coded in the original; it now defaults to the
        same value but can be overridden by callers.
        """
        # 1. read text from file
        text = self.read_text(file_path)
        print(len(text))
        # 2. vectorize text
        X, Y = self.vectorize(text)
        # 3. train based on X, Y
        model = self.train(X, Y)
        # 4. pick a random training window as the generation seed
        test_idx = np.random.randint(len(self.input_chars))
        test_chars = self.input_chars[test_idx]
        print('test seed is: %s' % test_chars)
        result = self.predict(model, test_chars)
        print('result is: %s' % result)
if __name__ == '__main__':
    # Entry point: run the full read -> vectorize -> train -> predict demo.
    RNNSimple().process()
输出结果如下：
Iteration: 0
Epoch 1/1
1739/1739 [==============================] - 0s 180us/step - loss: 3.1235
Iteration: 1
Epoch 1/1
1739/1739 [==============================] - 0s 42us/step - loss: 2.8657
Iteration: 2
Epoch 1/1
1739/1739 [==============================] - 0s 41us/step - loss: 2.7682
Iteration: 3
Epoch 1/1
1739/1739 [==============================] - 0s 42us/step - loss: 2.6973
Iteration: 4
Epoch 1/1
1739/1739 [==============================] - 0s 41us/step - loss: 2.6078
Iteration: 5
Epoch 1/1
1739/1739 [==============================] - 0s 42us/step - loss: 2.5333
Iteration: 6
Epoch 1/1
1739/1739 [==============================] - 0s 41us/step - loss: 2.4714
Iteration: 7
Epoch 1/1
1739/1739 [==============================] - 0s 41us/step - loss: 2.4033
Iteration: 8
Epoch 1/1
1739/1739 [==============================] - 0s 41us/step - loss: 2.3416
Iteration: 9
Epoch 1/1
1739/1739 [==============================] - 0s 42us/step - loss: 2.2842
Iteration: 10
Epoch 1/1
1739/1739 [==============================] - 0s 41us/step - loss: 2.2356
Iteration: 11
Epoch 1/1
1739/1739 [==============================] - 0s 40us/step - loss: 2.1725
Iteration: 12
Epoch 1/1
1739/1739 [==============================] - 0s 41us/step - loss: 2.1179
Iteration: 13
Epoch 1/1
1739/1739 [==============================] - 0s 41us/step - loss: 2.0791
Iteration: 14
Epoch 1/1
1739/1739 [==============================] - 0s 41us/step - loss: 2.0257
Iteration: 15
Epoch 1/1
1739/1739 [==============================] - 0s 41us/step - loss: 1.9759
Iteration: 16
Epoch 1/1
1739/1739 [==============================] - 0s 42us/step - loss: 1.9279
Iteration: 17
Epoch 1/1
1739/1739 [==============================] - 0s 41us/step - loss: 1.8844
Iteration: 18
Epoch 1/1
1739/1739 [==============================] - 0s 42us/step - loss: 1.8449
Iteration: 19
Epoch 1/1
1739/1739 [==============================] - 0s 41us/step - loss: 1.7889
Iteration: 20
Epoch 1/1
1739/1739 [==============================] - 0s 41us/step - loss: 1.7589
Iteration: 21
Epoch 1/1
1739/1739 [==============================] - 0s 41us/step - loss: 1.7087
Iteration: 22
Epoch 1/1
1739/1739 [==============================] - 0s 42us/step - loss: 1.6699
Iteration: 23
Epoch 1/1
1739/1739 [==============================] - 0s 41us/step - loss: 1.6291
test seed is: rrive at t
result is: rrive at the input nodes area ion is the sem of the seruen eo an te de the nm or es ae eamhen niin t ae ter