Keras用LSTM字母顺序预测

这个只是一个LSTM Demo,输入一个字母预测下一个字母。可扩展成数据预测、代码生成、聊天机器人等复杂应用。

import keras as K
import matplotlib.pyplot as plt
import numpy as np
import math

alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'  # 26
# 字符与序号对应的字典
char_to_int = dict((c, i) for i, c in enumerate(alphabet))
int_to_char = dict((i, c) for i, c in enumerate(alphabet))
input_length = 1  # 输入序列长度,训练的步长
input_dim = 1  # 输入维度,输入参数数量
output_length = 1  # 输出序列长度,输出的步长
output_dim = 1  # 输出维度,输出参数数量
dataX = []
dataY = []
# 训练数据集
for i in range(0, len(alphabet) - input_length - output_length + 1, 2):
    seq_in = alphabet[i:i + input_length]
    seq_out = alphabet[i + input_length: i + input_length + output_length]
    dataX+=[[[char_to_int[char]]] for char in seq_in]
    dataY+=[[[char_to_int[char]]] for char in seq_out]
    print(seq_in, '->', seq_out)

dataX=np.array(dataX)
dataY=np.array(dataY)
print('dataX',dataX)
print('dataX',dataX.shape)
print('dataY',dataY)
print('dataY',dataY.shape)

# 定义神经网络结构
input_value = K.Input((input_length, input_dim),
                      name="input")  # 1个时间维度,1个输入参数
print('input_value',input_value.shape)
# LSTM stateful有状态的,一个batch后状态不清空。返回全部序列
x = K.layers.LSTM(32, activation=K.activations.relu,
                  return_sequences=True)(input_value)
print('x',x.shape)
output_value = K.layers.Dense(output_length)(x)
print('output_value',output_value.shape)

model = K.Model(inputs=input_value, outputs=output_value, name="test")
# 编译,定义损失函数
model.compile(K.optimizers.Adadelta(lr=0.5),
                    loss=[K.losses.mean_absolute_error],
                    metrics=[K.metrics.mean_absolute_error])
# 训练
history = model.fit(dataX, dataY, epochs=500, batch_size=1, validation_data=(dataX, dataY))


'''
    对数据绘图
'''
plt.plot(history.history['loss'], label='train')
plt.plot(history.history['val_loss'], label='test')
plt.legend()
plt.show()

for i in range(0,len(alphabet)-1):
    out=model.predict(np.array([i]).reshape((-1,input_length, input_dim)))
    print(int_to_char[i],int_to_char[int(np.round(max(min(out,25),0),0))],out)
发布了28 篇原创文章 · 获赞 2 · 访问量 1万+

猜你喜欢

转载自blog.csdn.net/highlevels/article/details/89639958