Generating Acrostic Poems with an LSTM

I first came across an online tutorial video on generating Tang poems and typed out the code by following its example. Simply retyping what a tutorial shows only teaches you so much, so I started wondering: what about acrostic poems (藏头诗), where each line begins with a given character? Others have posted solutions online, but I found their source code awkward to use, so I modified my earlier code to do it myself (code download). Here is the approach:

1. Data Preprocessing

Here we read the corpus file (the training set consists entirely of Tang poems) and turn each poem into a vector of character IDs. The code is as follows:
import collections

start_token = 'G'
end_token = 'E'

def process_poems(file_name):
	poems = []
	with open(file_name, "r", encoding='utf-8') as f:
		for line in f.readlines():
			try:
				title, content = line.strip().split(':')
				content = content.replace(' ', '')
				if '_' in content or '(' in content or '（' in content or '《' in content or '[' in content or \
					start_token in content or end_token in content:
					continue
				if len(content) < 5 or len(content) > 79:
					continue
				content = start_token + content + end_token
				poems.append(content)
			except ValueError as e:
				pass
	# sort the poems by length
	poems = sorted(poems, key=lambda l: len(l))
	# collect every character so we can count occurrences
	all_words = []
	for poem in poems:
		all_words += [word for word in poem]
	# counter maps each character to its frequency
	counter = collections.Counter(all_words)
	count_pairs = sorted(counter.items(), key=lambda x: -x[1])
	words, _ = zip(*count_pairs)
	# keep every character and append a space as the padding token
	words = words[:len(words)] + (' ',)
	# map each character to an integer ID
	word_int_map = dict(zip(words, range(len(words))))
	poems_vector = [list(map(lambda word: word_int_map.get(word, len(words)), poem)) for poem in poems]

	return poems_vector, word_int_map, words
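
For reference, a minimal usage sketch of the preprocessing step; the path 'data/poems.txt' is only a placeholder (the real path comes from FLAGS.file_path later in the post):

# Hypothetical usage sketch; 'data/poems.txt' is a placeholder path.
poems_vector, word_int_map, vocabularies = process_poems('data/poems.txt')
print('number of poems:', len(poems_vector))
print('vocabulary size:', len(vocabularies))
print('first poem as IDs:', poems_vector[0])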

2. Generating Batches

Note that poems differ in length, so within each batch we pad them all to the length of the longest poem. x_data is then easy to build, and y_data is simply x_data shifted one position to the left: each target character is the character that follows it in the input.
import numpy as np

def generate_batch(batch_size, poems_vec, word_to_int):
	# train on batch_size poems at a time
	n_chunk = len(poems_vec) // batch_size
	x_batches = []
	y_batches = []

	for i in range(n_chunk):
		start_index = i * batch_size
		end_index = start_index + batch_size

		batches = poems_vec[start_index:end_index]
		# length of the longest poem in this batch
		length = max(map(len, batches))
		# pre-fill the batch with the index of the padding space
		x_data = np.full((batch_size, length), word_to_int[' '], np.int32)
		for row in range(batch_size):
			# each row is one poem; copy the poem over the padding
			x_data[row, :len(batches[row])] = batches[row]
		y_data = np.copy(x_data)
		# y is x shifted one position to the left
		y_data[:, :-1] = x_data[:, 1:]
		x_batches.append(x_data)
		y_batches.append(y_data)
	return x_batches, y_batches
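
A quick sanity check of the batching, reusing poems_vector and word_int_map from the sketch above (the batch size 64 matches the model below):

# Hypothetical sanity check, reusing the variables from the preprocessing sketch.
x_batches, y_batches = generate_batch(64, poems_vector, word_int_map)
print('number of batches:', len(x_batches))
print('shape of first batch:', x_batches[0].shape)  # (64, longest poem length in the batch)
# the targets are the inputs shifted one step to the left
print(x_batches[0][0, 1:6], y_batches[0][0, 0:5])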

3. Building the LSTM Model

The hidden size is 128, the network stacks two LSTM layers, and batch_size = 64.
import tensorflow as tf

def rnn_model(model, input_data, output_data, vocab_size, rnn_size=128, num_layers=2, batch_size=64, learning_rate=0.01):
	end_points = {}
	if model == 'rnn':
		cell_fun = tf.contrib.rnn.BasicRNNCell
	elif model == 'gru':
		cell_fun = tf.contrib.rnn.GRUCell
	elif model == 'lstm':
		cell_fun = tf.contrib.rnn.BasicLSTMCell

	# build one fresh cell per layer (reusing the same cell object across layers
	# raises an error in recent TF 1.x releases)
	cells = [cell_fun(rnn_size, state_is_tuple=True) for _ in range(num_layers)]
	cell = tf.contrib.rnn.MultiRNNCell(cells, state_is_tuple=True)

	if output_data is not None:
		initial_state = cell.zero_state(batch_size, tf.float32)
	else:
		initial_state = cell.zero_state(1, tf.float32)

	with tf.device("/cpu:0"):
		embedding = tf.get_variable('embedding', initializer=tf.random_uniform([vocab_size + 1, rnn_size], -1.0, 1.0))
		inputs = tf.nn.embedding_lookup(embedding, input_data)

	outputs, last_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=initial_state)
	output = tf.reshape(outputs, [-1, rnn_size])

	weights = tf.Variable(tf.truncated_normal([rnn_size, vocab_size + 1]))
	bias = tf.Variable(tf.zeros(shape=[vocab_size + 1]))
	logits = tf.nn.bias_add(tf.matmul(output, weights), bias=bias)

	if output_data is not None:
		# output_data must be one-hot encoded; labels shape is [?, vocab_size + 1]
		labels = tf.one_hot(tf.reshape(output_data, [-1]), depth=vocab_size + 1)
		loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
		total_loss = tf.reduce_mean(loss)
		train_op = tf.train.AdamOptimizer(learning_rate).minimize(total_loss)

		end_points['initial_state'] = initial_state
		end_points['output'] = output
		end_points['train_op'] = train_op
		end_points['total_loss'] = total_loss
		end_points['loss'] = loss
		end_points['last_state'] = last_state
	else:
		prediction = tf.nn.softmax(logits)
		end_points['initial_state'] = initial_state
		end_points['last_state'] = last_state
		end_points['prediction'] = prediction
	return end_points

4. Training the Model

We save the model periodically, for example every 6 epochs, so it can be reloaded later for testing. The checkpoints saved at different epochs are different models and generate poems in different styles, which also means the last saved checkpoint is not necessarily the best one.
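
run_training below references FLAGS (file_path, checkpoints_dir, model_prefix, batch_size, learning_rate, epochs), which the post never shows. Here is a minimal sketch using tf.app.flags; the paths and default values are my assumptions, not the original settings:

import os
import tensorflow as tf

# Hypothetical FLAGS definition; paths and defaults are placeholders.
tf.app.flags.DEFINE_integer('batch_size', 64, 'batch size')
tf.app.flags.DEFINE_float('learning_rate', 0.01, 'learning rate')
tf.app.flags.DEFINE_string('checkpoints_dir', './checkpoints/', 'checkpoints save path')
tf.app.flags.DEFINE_string('file_path', './data/poems.txt', 'path of the poem corpus')
tf.app.flags.DEFINE_string('model_prefix', 'poems', 'model save prefix')
tf.app.flags.DEFINE_integer('epochs', 50, 'number of training epochs')
FLAGS = tf.app.flags.FLAGS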
def run_training():
    if not os.path.exists(os.path.dirname(FLAGS.checkpoints_dir)):
        os.mkdir(os.path.dirname(FLAGS.checkpoints_dir))
    if not os.path.exists(FLAGS.checkpoints_dir):
        os.mkdir(FLAGS.checkpoints_dir)

    poems_vector, word_to_int, vocabularies = process_poems(FLAGS.file_path)
    batches_inputs, batches_outputs = generate_batch(FLAGS.batch_size, poems_vector, word_to_int)

    input_data = tf.placeholder(tf.int32, [FLAGS.batch_size, None])
    output_targets = tf.placeholder(tf.int32, [FLAGS.batch_size, None])

    end_points = rnn_model(model='lstm', input_data=input_data, output_data=output_targets, vocab_size=len(
        vocabularies), rnn_size=128, num_layers=2, batch_size=FLAGS.batch_size, learning_rate=FLAGS.learning_rate)

    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    with tf.Session() as sess:
        # sess = tf_debug.LocalCLIDebugWrapperSession(sess=sess)
        # sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)
        sess.run(init_op)

        start_epoch = 0
        checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoints_dir)
        if checkpoint:
            saver.restore(sess, checkpoint)
            print("[INFO] restore from the checkpoint {0}".format(checkpoint))
            start_epoch += int(checkpoint.split('-')[-1])
        print('[INFO] start training...')
        try:
            for epoch in range(start_epoch, FLAGS.epochs):
                n = 0
                n_chunk = len(poems_vector) // FLAGS.batch_size
                for batch in range(n_chunk):
                    loss, _, _ = sess.run([
                        end_points['total_loss'],
                        end_points['last_state'],
                        end_points['train_op']
                    ], feed_dict={input_data: batches_inputs[n], output_targets: batches_outputs[n]})
                    n += 1
                    print('[INFO] Epoch: %d , batch: %d , training loss: %.6f' % (epoch, batch, loss))
                    
                if epoch % 6 == 0:
                    saver.save(sess, './model/', global_step=epoch)
                    #saver.save(sess, os.path.join(FLAGS.checkpoints_dir, FLAGS.model_prefix), global_step=epoch)
        except KeyboardInterrupt:
            print('[INFO] Interrupt manually, try saving checkpoint for now...')
            saver.save(sess, os.path.join(FLAGS.checkpoints_dir, FLAGS.model_prefix), global_step=epoch)
            print('[INFO] Last epoch was saved; next time training will resume from epoch {}.'.format(epoch))
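
The post does not show how the script is launched; a minimal entry point, assuming the FLAGS sketch above, would look like this:

# Hypothetical entry point; the original post does not show how the script is started.
def main(_):
    run_training()

if __name__ == '__main__':
    tf.app.run()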

5. Testing the Model

This step looks simple but is not. For an acrostic poem, each given head character must be used to start and generate one line of the poem. This is the part I changed most from the code I wrote while following the tutorial video, turning it into an acrostic generator.
def gen_poem(begin_words,type):
    if type !=5 and type != 7:
        print('The second parameter has to be 5 or 7!')
        return
    batch_size = 1
    print('[INFO] loading corpus from %s' % FLAGS.file_path)
    poems_vector, word_int_map, vocabularies = process_poems(FLAGS.file_path)

    input_data = tf.placeholder(tf.int32, [batch_size, None])

    end_points = rnn_model(model='lstm', input_data=input_data, output_data=None, vocab_size=len(
        vocabularies), rnn_size=128, num_layers=2, batch_size=64, learning_rate=FLAGS.learning_rate)

    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    with tf.Session() as sess:
        sess.run(init_op)

        #checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoints_dir)
        checkpoint = tf.train.latest_checkpoint('./model/')
        #saver.restore(sess, checkpoint)
        saver.restore(sess, './model/-36')
        poem = ''
        for head in begin_words:
            flag = True
            while flag:
                # state_ = sess.run(cell.zero_state(1, tf.float32))
                x = np.array([list(map(word_int_map.get, start_token))])
                # print('x=',x)

                [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
                                                 feed_dict={input_data: x})
                # if begin_word:
                #     word = begin_word
                # else:
                #     word = to_word(predict, vocabularies)
                sentence = head
                x = np.zeros((1, 1))
                x[0, 0] = word_int_map[sentence]
                [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
                                                 feed_dict={input_data: x, end_points['initial_state']: last_state})
                word = to_word(predict, vocabularies)
                sentence += word
                print(sentence)

                while word != u'。':
                    print ('running')
                    # poem += word
                    # print(word)
                    x = np.zeros((1, 1))
                    x[0, 0] = word_int_map[word]
                    # print('x2=', x)
                    [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
                                                     feed_dict={input_data: x, end_points['initial_state']: last_state})
                    word = to_word(predict, vocabularies)
                    sentence += word
                    print(sentence,len(sentence))

                if len(sentence) == 2 + 2 * type:
                    sentence += '\n'
                    poem += sentence
                    flag = False
                    # print('word=', word)
                # word = words[np.argmax(probs_)]
        return poem
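
gen_poem above calls a helper to_word that is not shown in the post; it maps the predicted probability distribution back to a character. Below is a minimal sketch of a common implementation (sampling from the cumulative distribution), followed by a hypothetical call that hides '春夏秋冬' as the head characters of a five-character poem:

# Sketch of the to_word helper used above (not shown in the original post):
# sample a character index from the predicted probability distribution.
def to_word(predict, vocabs):
    # predict has shape (1, vocab_size + 1); flatten and sample from the distribution
    predict = np.reshape(predict, [-1])
    t = np.cumsum(predict)
    s = np.sum(predict)
    sample = int(np.searchsorted(t, np.random.rand() * s))
    if sample > len(vocabs) - 1:
        sample = len(vocabs) - 1
    return vocabs[sample]

# Hypothetical usage: hide '春夏秋冬' as the head characters of a 5-character poem.
# print(gen_poem('春夏秋冬', 5))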

The generated result is shown below:

执遮白云法,山川深梦里。
子宓心越思,山欲下童儿。
之子自不索,朱师能识诗。
手计似家路,知与休火真。
与君归去第,西令灵六千。
子黏若藏文,葩川不上文。
偕篮惊此鹤,犹真似被鸡。
老邑人间衲,一名养论化。


Reposted from blog.csdn.net/pursue_myheart/article/details/81044403