版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
古诗生成—用LSTM
太懒了,数据集明天再传
整体流程
完整代码:
import numpy as np
#from collections import Counter
from tensorflow import keras
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.python.ops import summary_ops_v2
import time
import os
# Preprocess the corpus: produce vocab_to_int, int_to_vocab and the whole
# corpus encoded as an integer sequence.
def create_lookup_tables(path='newtxt.txt', encoding=None):
    """Build character lookup tables from a poem corpus file.

    Reads the whole file, builds a deterministic (sorted) character
    vocabulary, and encodes every non-newline character as an integer.

    Args:
        path: corpus file to read (default keeps the original hard-coded
            file name, so existing callers are unaffected).
        encoding: passed straight to ``open``; ``None`` keeps the platform
            default, matching the original behaviour.
            # NOTE(review): the corpus is presumably UTF-8 — confirm and
            # pass encoding='utf-8' explicitly.

    Returns:
        (vocab_to_int, int_to_vocab, int_text): char->id dict, id->char
        dict, and a numpy array of ids for every non-newline character.
    """
    with open(path, 'r', encoding=encoding) as f:
        text = f.read()
    print(len(text))
    # sorted(set(...)) gives a stable char -> id assignment across runs.
    vocab = sorted(set(text))
    print(len(vocab))
    vocab_to_int = {u: i for i, u in enumerate(vocab)}
    int_to_vocab = {i: u for i, u in enumerate(vocab)}
    # Newlines are dropped from the encoded sequence; '\n' may still be in
    # the vocab, which keeps the ids consistent with the lookup tables.
    int_text = np.array([vocab_to_int[word] for word in text if word != '\n'])
    print(len(int_text))
    return vocab_to_int, int_to_vocab, int_text
# Build the global lookup tables and the encoded corpus used below.
vocab_to_int,int_to_vocab,int_text = create_lookup_tables()
# Batch slicer: turns the encoded corpus into (input, target) batch pairs.
def get_batches(int_text, batch_size, seq_length):
    """Slice the encoded corpus into next-character training batches.

    Each batch pairs an input window with a target window shifted one
    position to the right.

    Args:
        int_text: 1-D sequence of token ids.
        batch_size: number of sequences per batch.
        seq_length: number of time steps per sequence.

    Returns:
        numpy array of shape (n_batches, 2, batch_size, seq_length) where
        result[i][0] is the input batch and result[i][1] the target batch.
    """
    words_per_batch = batch_size * seq_length
    n_batches = len(int_text) // words_per_batch
    total = n_batches * words_per_batch
    inputs = np.asarray(int_text[:total])
    # Targets are the inputs shifted left by one.  When the corpus length
    # is an exact multiple of words_per_batch the original slice came up
    # one element short and the reshape crashed; wrap the final target
    # around to the first token instead.
    targets = np.asarray(int_text[1:total + 1])
    if len(targets) < total:
        targets = np.append(targets, int_text[0])
    # reshape(batch_size, -1) + split along the time axis reproduces the
    # original reshape(1, batch_size, -1) + dsplit + [0] exactly, without
    # the dummy leading axis.
    x_batches = np.split(inputs.reshape(batch_size, -1), n_batches, axis=1)
    y_batches = np.split(targets.reshape(batch_size, -1), n_batches, axis=1)
    return np.array(list(zip(x_batches, y_batches)))
###########
# Vocabulary size (number of distinct characters).
vocab_size = len(int_to_vocab)
# Batch size.
batch_size = 32 # 64
# RNN size (dimensionality of the LSTM hidden state h).
rnn_size = 1000
# Embedding dimension.
embed_dim = 256 # adjusted here; differs from the lottery-prediction version
# Sequence length: the LSTM is unrolled over this many steps.
seq_length = 15 # i.e. 15 RNN cells
save_dir = './save'
###########
MODEL_DIR = "./poetry_models"
train_batches = get_batches(int_text, batch_size, seq_length)
losses = {'train': [], 'test': []}
##########
# Build the network.
class poetry_network(object):
    """Character-level poem model: Embedding -> stateful LSTM -> Dense logits."""

    def __init__(self,batch_size=32):
        """Build the model, optimizer, loss, summary writer and checkpoint.

        batch_size: fixed batch dimension baked into the model's input
        shape (pass 1 when restoring the model for generation).
        """
        self.model = keras.Sequential([
            keras.layers.Embedding(vocab_size,embed_dim,
                batch_input_shape=[batch_size,None]),
            # stateful=True carries the LSTM state across batches until
            # reset_states() is called.
            keras.layers.LSTM(rnn_size,return_sequences=True,
                stateful=True,recurrent_initializer='glorot_uniform'),
            keras.layers.Dense(vocab_size)
        ])
        self.model.summary()  # print the model layout (optional)
        self.optimizer = keras.optimizers.Adam()  # optimization strategy
        # from_logits=True: the Dense layer emits raw logits (no softmax).
        self.ComputeLoss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        # Make sure the model directory exists before writing into it.
        if tf.io.gfile.exists(MODEL_DIR):
            pass
        else:
            tf.io.gfile.makedirs(MODEL_DIR)
        # NOTE(review): original author reported this directory sometimes
        # had to be created manually — verify makedirs succeeds here.
        train_dir = os.path.join(MODEL_DIR,'summaries','train')
        # Summary writer for training logs; flush_millis is the flush
        # interval in milliseconds.
        self.train_summary_writer = summary_ops_v2.create_file_writer(train_dir,flush_millis=10000)
        # Checkpointing: save/restore both model and optimizer state.
        checkpoint_dir = os.path.join(MODEL_DIR,"checkpoints")
        self.checkpoint_prefix = os.path.join(checkpoint_dir,'ckpt')  # file-name prefix
        self.checkpoint = tf.train.Checkpoint(model=self.model,optimizer=self.optimizer)
        # Resume from the latest checkpoint if one exists (no-op otherwise).
        self.checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

    def train_step(self,x,y):
        """One training step: forward pass, loss, backprop, weight update.

        Returns (loss, logits) for logging.
        """
        with tf.GradientTape() as tape:
            logits = self.model(x,training=True)
            loss = self.ComputeLoss(y,logits)
        grads = tape.gradient(loss,self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(grads,self.model.trainable_variables))
        return loss,logits

    def training(self,epochs=1,log_freq=50):
        """Training loop over the module-level train_batches.

        epochs: number of passes over the batches.
        log_freq: log/summary cadence, in optimizer steps.
        """
        batchCnt = len(int_text) // (batch_size * seq_length)
        print("batchCnt的值是",batchCnt)
        for i in range(epochs):
            train_start = time.time()
            with self.train_summary_writer.as_default():
                start = time.time()
                # Running mean of the loss between log points.
                avg_loss = keras.metrics.Mean('loss',dtype = tf.float32)
                for batch_i,(x,y) in enumerate(train_batches):
                    # Single training step.
                    loss,logits = self.train_step(x,y)
                    avg_loss(loss)
                    losses['train'].append(loss)
                    # Every log_freq optimizer steps, emit a summary and
                    # print progress.
                    if tf.equal(self.optimizer.iterations % log_freq,0):
                        summary_ops_v2.scalar('loss',avg_loss.result(),
                            step=self.optimizer.iterations)
                        rate = log_freq / (time.time() - start)
                        print('Step #{}\tLoss:{:0.6f}({}steps/sec)'.format(self.optimizer.iterations.numpy(),loss,rate))
                        avg_loss.reset_states()
                        start = time.time()
            # Save once per epoch.
            self.checkpoint.save(self.checkpoint_prefix)
            print('模型已经保存(迭代一次保存一次)')
# Start training.
net = poetry_network()  # create a network instance
net.training(2)  # argument is the number of epochs
# Reload the saved model; the 1 is the batch_size for generation.
restore_net=poetry_network(1)
# Rebuild the model's input shape for batch size 1 before inference.
restore_net.model.build(tf.TensorShape([1, None]))
def gen_poetry(prime_word='白', top_n=5, rule=7, sentence_lines=4, hidden_head=None):
    """Generate a poem with the restored model (restore_net).

    Args:
        prime_word: seed character(s) the poem starts with.
        top_n: unused by the current sampling scheme; kept for interface
            compatibility.
        rule: characters per line (e.g. 5 or 7 for classical verse).
        sentence_lines: number of lines in the poem.
        hidden_head: optional string for an acrostic poem; its characters
            are forced at the start of each line.

    Returns:
        The generated poem as a space-joined string, with ',' / '。'
        punctuation inserted at the end of each line.
    """
    gen_length = sentence_lines * (rule + 1) - len(prime_word)
    gen_sentences = [prime_word] if hidden_head is None else [hidden_head[0]]
    temperature = 1.0
    # Encode the seed and add a leading batch dimension.
    dyn_input = [vocab_to_int[s] for s in prime_word]
    dyn_input = tf.expand_dims(dyn_input, 0)
    # Clear the stateful LSTM state before generating.
    restore_net.model.reset_states()
    index = len(prime_word) if hidden_head is None else 1
    for n in range(gen_length):
        index += 1
        predictions = restore_net.model(np.array(dyn_input))
        predictions = tf.squeeze(predictions, 0)
        if index != 0 and (index % (rule + 1)) == 0:
            # End of a line: alternate between ',' and '。'.
            if ((index / (rule + 1)) + 1) % 2 == 0:
                predicted_id = vocab_to_int[',']
            else:
                predicted_id = vocab_to_int['。']
        else:
            if hidden_head is not None and (index - 1) % (rule + 1) == 0 and (index - 1) // (rule + 1) < len(hidden_head):
                # Acrostic mode: force the corresponding hidden_head
                # character at the start of each line.
                predicted_id = vocab_to_int[hidden_head[(index - 1) // (rule + 1)]]
            else:
                # Apply temperature once.  The original divided inside the
                # retry loop, compounding the division on every rejected
                # sample (a no-op only because temperature == 1.0).
                predictions = predictions / temperature
                while True:
                    # Sample from the categorical distribution over the
                    # logits; reject punctuation, which is inserted by the
                    # line-end rule above instead.
                    predicted_id = tf.random.categorical(predictions, num_samples=1)[-1, 0].numpy()
                    if predicted_id != vocab_to_int[','] and predicted_id != vocab_to_int['。']:
                        break
        # Feed the chosen character back as the next input.
        dyn_input = tf.expand_dims([predicted_id], 0)
        gen_sentences.append(int_to_vocab[predicted_id])
    poetry_script = ' '.join(gen_sentences)
    poetry_script = poetry_script.replace('\n ', '\n')
    poetry_script = poetry_script.replace('( ', '(')
    return poetry_script
# Generate and print a sample poem: 4 lines of 4 characters, seeded with '少杰'.
result = gen_poetry(prime_word='少杰', top_n=10, rule=4, sentence_lines=4,hidden_head=None)
print(result)