I first came across Tang poetry generation in an online tutorial video and typed out the code by hand following its example. If you only copy the code as given, what you learn is limited, so an idea occurred to me: what about acrostic poems (藏头诗)? How would that work? Others online have published solutions, but the source code I found was not very usable, so I finished this by modifying my own earlier code.
Code download. Below I walk through the approach:
1. Data preprocessing
Here we read the corpus file; the training set consists entirely of Tang poems. Each poem is then turned into a vector of character IDs. The code is as follows:
import collections

start_token = 'G'
end_token = 'E'

def process_poems(file_name):
    poems = []
    with open(file_name, "r", encoding='utf-8') as f:
        for line in f.readlines():
            try:
                title, content = line.strip().split(':')
                content = content.replace(' ', '')
                # skip poems containing irregular characters or the special tokens
                if '_' in content or '(' in content or '(' in content or '《' in content or '[' in content or \
                        start_token in content or end_token in content:
                    continue
                # skip poems that are too short or too long
                if len(content) < 5 or len(content) > 79:
                    continue
                content = start_token + content + end_token
                poems.append(content)
            except ValueError as e:
                pass
    # sort poems by length (note: the key must be len(l), not len(line))
    poems = sorted(poems, key=lambda l: len(l))
    # count how often each character occurs
    all_words = []
    for poem in poems:
        all_words += [word for word in poem]
    # counter maps each character to its frequency
    counter = collections.Counter(all_words)
    count_pairs = sorted(counter.items(), key=lambda x: -x[1])
    words, _ = zip(*count_pairs)
    # keep the full vocabulary and append a space as the padding character
    words = words[:len(words)] + (' ',)
    # map each character to an integer ID
    word_int_map = dict(zip(words, range(len(words))))
    poems_vector = [list(map(lambda word: word_int_map.get(word, len(words)), poem)) for poem in poems]
    return poems_vector, word_int_map, words
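As a quick sanity check, here is a hedged usage sketch. The path ./data/poems.txt and the sample line are assumptions; what the code above actually requires is that each corpus line look like 题目:内容 (title, full-width colon, content):

# Assumed corpus path; each line looks like "静夜思:床前明月光,疑是地上霜。..."
poems_vector, word_int_map, words = process_poems('./data/poems.txt')
print(len(poems_vector))      # number of usable poems after filtering
print(poems_vector[0][:10])   # first poem as a list of character IDs
print(words[:10])             # the ten most frequent characters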
2. Generating batches
The point to watch here is that poems have different lengths, so every poem in a batch must be padded to the same length. x_data is then easy to build, and y_data is simply x_data shifted one position: each target character is the next input character.
import numpy as np

def generate_batch(batch_size, poems_vec, word_to_int):
    # train on batch_size (e.g. 64) poems at a time
    n_chunk = len(poems_vec) // batch_size
    x_batches = []
    y_batches = []
    for i in range(n_chunk):
        start_index = i * batch_size
        end_index = start_index + batch_size
        batches = poems_vec[start_index:end_index]
        # length of the longest poem in this batch
        length = max(map(len, batches))
        # allocate the batch; empty positions hold the index of the space character
        x_data = np.full((batch_size, length), word_to_int[' '], np.int32)
        for row in range(batch_size):
            # each row is one poem; copy the poem over its original length
            x_data[row, :len(batches[row])] = batches[row]
        y_data = np.copy(x_data)
        # y is x shifted one position to the left
        y_data[:, :-1] = x_data[:, 1:]
        x_batches.append(x_data)
        y_batches.append(y_data)
    return x_batches, y_batches
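To make the padding and the shift concrete, here is a small hedged example; the IDs and "poems" are invented purely for illustration:

# Hypothetical toy data: two "poems" as ID lists, the space character has ID 0
toy_word_to_int = {' ': 0}
toy_poems = [[5, 8, 2, 9], [5, 3, 7]]
xs, ys = generate_batch(2, toy_poems, toy_word_to_int)
print(xs[0])
# [[5 8 2 9]
#  [5 3 7 0]]   <- the shorter poem is padded with the space ID
print(ys[0])
# [[8 2 9 9]
#  [3 7 0 0]]   <- each row shifted left by one; the model predicts the next character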
3. Building the LSTM model
The hidden layer size is 128, the LSTM stack is two layers deep, and batch_size=64:
import tensorflow as tf

def rnn_model(model, input_data, output_data, vocab_size, rnn_size=128, num_layers=2, batch_size=64, learning_rate=0.01):
    end_points = {}
    if model == 'rnn':
        cell_fun = tf.contrib.rnn.BasicRNNCell
    elif model == 'gru':
        cell_fun = tf.contrib.rnn.GRUCell
    elif model == 'lstm':
        cell_fun = tf.contrib.rnn.BasicLSTMCell
    # build a separate cell per layer; reusing one cell object ([cell] * num_layers)
    # makes the layers share weights and raises an error in newer TF 1.x releases
    cell = tf.contrib.rnn.MultiRNNCell(
        [cell_fun(rnn_size, state_is_tuple=True) for _ in range(num_layers)], state_is_tuple=True)
    if output_data is not None:
        initial_state = cell.zero_state(batch_size, tf.float32)
    else:
        initial_state = cell.zero_state(1, tf.float32)
    with tf.device("/cpu:0"):
        embedding = tf.get_variable('embedding',
                                    initializer=tf.random_uniform([vocab_size + 1, rnn_size], -1.0, 1.0))
        inputs = tf.nn.embedding_lookup(embedding, input_data)
    outputs, last_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=initial_state)
    output = tf.reshape(outputs, [-1, rnn_size])
    weights = tf.Variable(tf.truncated_normal([rnn_size, vocab_size + 1]))
    bias = tf.Variable(tf.zeros(shape=[vocab_size + 1]))
    logits = tf.nn.bias_add(tf.matmul(output, weights), bias=bias)
    if output_data is not None:
        # labels must be one-hot encoded, shape [?, vocab_size + 1]
        labels = tf.one_hot(tf.reshape(output_data, [-1]), depth=vocab_size + 1)
        # cross entropy per character; the resulting loss shape is [?]
        loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
        total_loss = tf.reduce_mean(loss)
        train_op = tf.train.AdamOptimizer(learning_rate).minimize(total_loss)
        end_points['initial_state'] = initial_state
        end_points['output'] = output
        end_points['train_op'] = train_op
        end_points['total_loss'] = total_loss
        end_points['loss'] = loss
        end_points['last_state'] = last_state
    else:
        prediction = tf.nn.softmax(logits)
        end_points['initial_state'] = initial_state
        end_points['last_state'] = last_state
        end_points['prediction'] = prediction
    return end_points
4. Training the model
We need to save the model as we train, for example a checkpoint every 6 epochs, so we can test later. Checkpoints from different epochs are different models and generate poems in different styles, which also means the last checkpoint saved is not necessarily the best one.
import os

def run_training():
    if not os.path.exists(os.path.dirname(FLAGS.checkpoints_dir)):
        os.mkdir(os.path.dirname(FLAGS.checkpoints_dir))
    if not os.path.exists(FLAGS.checkpoints_dir):
        os.mkdir(FLAGS.checkpoints_dir)
    poems_vector, word_to_int, vocabularies = process_poems(FLAGS.file_path)
    batches_inputs, batches_outputs = generate_batch(FLAGS.batch_size, poems_vector, word_to_int)
    input_data = tf.placeholder(tf.int32, [FLAGS.batch_size, None])
    output_targets = tf.placeholder(tf.int32, [FLAGS.batch_size, None])
    # batch_size must match the placeholders, so pass FLAGS.batch_size rather than a hard-coded 64
    end_points = rnn_model(model='lstm', input_data=input_data, output_data=output_targets,
                           vocab_size=len(vocabularies), rnn_size=128, num_layers=2,
                           batch_size=FLAGS.batch_size, learning_rate=FLAGS.learning_rate)
    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    with tf.Session() as sess:
        # sess = tf_debug.LocalCLIDebugWrapperSession(sess=sess)
        # sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)
        sess.run(init_op)
        start_epoch = 0
        checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoints_dir)
        if checkpoint:
            saver.restore(sess, checkpoint)
            print("[INFO] restore from the checkpoint {0}".format(checkpoint))
            start_epoch += int(checkpoint.split('-')[-1])
        print('[INFO] start training...')
        try:
            for epoch in range(start_epoch, FLAGS.epochs):
                n = 0
                n_chunk = len(poems_vector) // FLAGS.batch_size
                for batch in range(n_chunk):
                    loss, _, _ = sess.run([
                        end_points['total_loss'],
                        end_points['last_state'],
                        end_points['train_op']
                    ], feed_dict={input_data: batches_inputs[n], output_targets: batches_outputs[n]})
                    n += 1
                    print('[INFO] Epoch: %d , batch: %d , training loss: %.6f' % (epoch, batch, loss))
                if epoch % 6 == 0:
                    saver.save(sess, './model/', global_step=epoch)
                    # saver.save(sess, os.path.join(FLAGS.checkpoints_dir, FLAGS.model_prefix), global_step=epoch)
        except KeyboardInterrupt:
            print('[INFO] Interrupt manually, try saving checkpoint for now...')
            saver.save(sess, os.path.join(FLAGS.checkpoints_dir, FLAGS.model_prefix), global_step=epoch)
            print('[INFO] Last epoch was saved; next time training will resume from epoch {}.'.format(epoch))
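FLAGS is used throughout but never defined in this section. A minimal sketch of the definitions it implies follows, using tf.app.flags; all default values and paths here are assumptions, so adjust them to your own setup:

# Assumed flag definitions; only the flag names are dictated by the code above
tf.app.flags.DEFINE_integer('batch_size', 64, 'batch size')
tf.app.flags.DEFINE_float('learning_rate', 0.01, 'learning rate')
tf.app.flags.DEFINE_string('checkpoints_dir', './checkpoints/poems/', 'checkpoints save path')
tf.app.flags.DEFINE_string('file_path', './data/poems.txt', 'path of the poem corpus file')
tf.app.flags.DEFINE_string('model_prefix', 'poems', 'model save prefix')
tf.app.flags.DEFINE_integer('epochs', 50, 'number of training epochs')
FLAGS = tf.app.flags.FLAGS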
5. Testing the model
This looks simple but is not that easy. The key point for an acrostic poem is that each given head character must be used to generate its own line. I wrote this based on the code from the video I had followed; this is the part I changed the most to turn plain generation into acrostic generation.
def gen_poem(begin_words, type):
    # type is the number of characters per half-line: only 5 or 7 are supported
    if type != 5 and type != 7:
        print('The second para has to be 5 or 7!')
        return
    batch_size = 1
    print('[INFO] loading corpus from %s' % FLAGS.file_path)
    poems_vector, word_int_map, vocabularies = process_poems(FLAGS.file_path)
    input_data = tf.placeholder(tf.int32, [batch_size, None])
    end_points = rnn_model(model='lstm', input_data=input_data, output_data=None,
                           vocab_size=len(vocabularies), rnn_size=128, num_layers=2,
                           batch_size=64, learning_rate=FLAGS.learning_rate)
    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    with tf.Session() as sess:
        sess.run(init_op)
        # checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoints_dir)
        checkpoint = tf.train.latest_checkpoint('./model/')
        # saver.restore(sess, checkpoint)
        saver.restore(sess, './model/-36')
        poem = ''
        for head in begin_words:
            # keep regenerating this line until it has exactly the requested length
            flag = True
            while flag:
                # prime the network with the start token
                x = np.array([list(map(word_int_map.get, start_token))])
                [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
                                                 feed_dict={input_data: x})
                # feed the head character as the first character of this line
                sentence = head
                x = np.zeros((1, 1))
                x[0, 0] = word_int_map[sentence]
                [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
                                                 feed_dict={input_data: x,
                                                            end_points['initial_state']: last_state})
                word = to_word(predict, vocabularies)
                sentence += word
                print(sentence)
                # keep sampling characters until a full stop ends the line
                while word != u'。':
                    print('running')
                    x = np.zeros((1, 1))
                    x[0, 0] = word_int_map[word]
                    [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
                                                     feed_dict={input_data: x,
                                                                end_points['initial_state']: last_state})
                    word = to_word(predict, vocabularies)
                    sentence += word
                    print(sentence, len(sentence))
                # accept the line only if it is exactly type + 1 + type + 1 characters long
                if len(sentence) == 2 + 2 * type:
                    sentence += '\n'
                    poem += sentence
                    flag = False
    return poem
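gen_poem calls a helper to_word that is not shown above; it turns the predicted probability distribution into the next character. A minimal sketch of one common implementation (cumulative-probability sampling) follows; treat it as an assumption about the original helper rather than the author's exact code:

def to_word(predict, vocabs):
    # sample a vocabulary index proportionally to the predicted probabilities
    t = np.cumsum(predict)
    s = np.sum(predict)
    sample = int(np.searchsorted(t, np.random.rand(1) * s))
    # clip out-of-range samples to the last vocabulary entry
    if sample > len(vocabs) - 1:
        sample = len(vocabs) - 1
    return vocabs[sample]

# hypothetical call: a 5-character acrostic poem hiding the phrase 春夏秋冬
print(gen_poem(u'春夏秋冬', 5))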
The generated result is shown below; the head characters of the eight lines spell out the hidden phrase 执子之手与子偕老:
执遮白云法,山川深梦里。
子宓心越思,山欲下童儿。
之子自不索,朱师能识诗。
手计似家路,知与休火真。
与君归去第,西令灵六千。
子黏若藏文,葩川不上文。
偕篮惊此鹤,犹真似被鸡。
老邑人间衲,一名养论化。