Building an LSTM Word Predictor with TensorFlow

1 Import Libraries

import os
import io
import re
import requests
import string
import collections
import random
import pickle
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

2 Load the Data

data_dir = 'temp'
data_file = 'shakespeare.txt'
model_path = 'shakespeare_model'
full_model_dir = os.path.join(data_dir, model_path)
punctuation = string.punctuation
punctuation = ''.join([x for x in punctuation if x not in ['-',"'"]]) # !"#$%&()*+,./:;<=>?@[\]^_`{|}~

if not os.path.exists(full_model_dir):
    os.makedirs(full_model_dir)
if not os.path.exists(data_dir):
    os.makedirs(data_dir)

print('Loading Shakespeare Data')

if not os.path.isfile(os.path.join(data_dir, data_file)):
    print('Not found, downloading Shakespeare texts from www.gutenberg.org')
    shakespeare_url = 'http://www.gutenberg.org/cache/epub/100/pg100.txt'
    response = requests.get(shakespeare_url)
    shakespeare_file = response.content
    s_text = shakespeare_file.decode('utf-8')
    s_text = s_text[7675:]
    s_text = s_text.replace('\r\n', '')
    s_text = s_text.replace('\n', '')
    
    with open(os.path.join(data_dir, data_file), 'w') as out_conn:
        out_conn.write(s_text)
else:
    with open(os.path.join(data_dir, data_file), 'r') as file_conn:
        s_text = file_conn.read().replace('\n', '')
Loading Shakespeare Data

3 Clean the Data

s_text = re.sub(r'[{}]'.format(punctuation), ' ', s_text)
s_text = re.sub(r'\s+', ' ', s_text).strip().lower()
print(s_text[0:1000])
from fairest creatures we desire increase that thereby beauty's rose might never die but as the riper should by time decease his tender heir might bear his memory but thou contracted to thine own bright eyes feed'st thy light's flame with self-substantial fuel making a famine where abundance lies thy self thy foe to thy sweet self too cruel thou that art now the world's fresh ornament and only herald to the gaudy spring within thine own bud buriest thy content and tender churl mak'st waste in niggarding pity the world or else this glutton be to eat the world's due by the grave and thee 2 when forty winters shall besiege thy brow and dig deep trenches in thy beauty's field thy youth's proud livery so gazed on now will be a tattered weed of small worth held then being asked where all thy beauty lies where all the treasure of thy lusty days to say within thine own deep sunken eyes were an all-eating shame and thriftless praise how much more praise deserved thy beauty's use if thou couldst
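
As a quick sanity check on the two re.sub calls, here is a toy line (not from the corpus): the first substitution replaces every punctuation character except - and ' with a space, and the second collapses runs of whitespace.

# Toy example of the cleaning step (hypothetical input, not from the corpus)
toy = "What's  in a name?\nThat which we call a rose..."
toy = re.sub(r'[{}]'.format(punctuation), ' ', toy)
toy = re.sub(r'\s+', ' ', toy).strip().lower()
print(toy)  # what's in a name that which we call a rose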

4 Build the Vocabulary

min_word_freq = 5
def build_vocab(text, min_word_freq):
    word_counts = collections.Counter(text.split(' '))
    word_counts = {key: counts for key, counts in word_counts.items() if counts > min_word_freq}
    words = word_counts.keys()
    vocab_to_ix_dict = {key: (ix + 1) for ix, key in enumerate(words)}
    vocab_to_ix_dict['unknown'] = 0
    ix_to_vocab_dict = {ix: word for word, ix in vocab_to_ix_dict.items()}
    return ix_to_vocab_dict, vocab_to_ix_dict

ix2vocab, vocab2ix = build_vocab(s_text, min_word_freq)
vocab_size = len(ix2vocab)  + 1
print(vocab_size)
8009
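
build_vocab can also be exercised on a tiny made-up string. With min_word_freq=0 every word is kept, index 0 is reserved for 'unknown', and real words start at 1 (the exact index order may differ across Python versions).

# Toy example with hypothetical input; index order may vary
toy_ix2vocab, toy_vocab2ix = build_vocab('to be or not to be', 0)
print(toy_vocab2ix)  # e.g. {'to': 1, 'be': 2, 'or': 3, 'not': 4, 'unknown': 0}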

5 Convert the Text to Indices

s_text_words = s_text.split(' ')
s_text_ix = np.array([vocab2ix[word] if word in vocab2ix.keys() else 0 for word in s_text_words])  # words whose frequency is at or below min_word_freq are not in the vocabulary and map to index 0 ('unknown')
print(s_text_ix)
[6232 1204  803 ... 3434 6628 1863]
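
To verify the conversion, the first few indices can be decoded back through ix2vocab (index 0 decodes to 'unknown').

# Decode the first few indices back to words as a sanity check
print([ix2vocab[ix] for ix in s_text_ix[0:5]])  # should roughly match 'from fairest creatures we desire'; any rare word appears as 'unknown'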

6 LSTM Model

rnn_size = 128  # RNN Model size
epochs = 1  # Number of epochs to cycle through data
batch_size = 100  # Train on this many examples at once
learning_rate = 0.001  # Learning rate
training_seq_len = 50  # how long of a word group to consider
embedding_size = rnn_size  # Word embedding size
save_every = 500  # How often to save model checkpoints
eval_every = 50  # How often to evaluate the test sentences
prime_texts = ['thou art more', 'to be or not to', 'wherefore art thou']
sess = tf.Session()
class LSTM_Model():
    def __init__(self,embedding_size, rnn_size, batch_size, learning_rate, training_seq_len, vocab_size, infer=False):
        self.rnn_size = rnn_size
        self.vocab_size = vocab_size
        self.infer = infer
        self.learning_rate = learning_rate
        if self.infer:
            self.batch_size = 1
            self.training_seq_len = 1
        else:
            self.batch_size = batch_size
            self.training_seq_len = training_seq_len
        self.lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(rnn_size)
        self.initial_state = self.lstm_cell.zero_state(self.batch_size, tf.float32)
        self.x_data = tf.placeholder(shape=[self.batch_size, self.training_seq_len], dtype=tf.int32)
        self.y_output = tf.placeholder(shape=[self.batch_size, self.training_seq_len], dtype=tf.int32)
        with tf.variable_scope('lstm_vars'):
            W = tf.get_variable('W', shape=[self.rnn_size, self.vocab_size], dtype=tf.float32, initializer=tf.random_normal_initializer())
            b = tf.get_variable('b', shape=[self.vocab_size], dtype=tf.float32, initializer=tf.random_normal_initializer())
            embedding_mat = tf.get_variable('embedding_mat', shape=[self.vocab_size, self.rnn_size], dtype=tf.float32, initializer=tf.random_normal_initializer())
            embedding_output = tf.nn.embedding_lookup(embedding_mat, self.x_data)
            rnn_inputs = tf.split(axis=1, num_or_size_splits=self.training_seq_len, value=embedding_output)  # split embedding_output along axis 1 into training_seq_len tensors
            rnn_inputs_trimmed = [tf.squeeze(x, [1]) for x in rnn_inputs] # Removes dimensions of size 1 from the shape of a tensor
        def inferred_loop(prev, count):
            prev_transformed = tf.matmul(prev, W) + b
            prev_symbol = tf.stop_gradient(tf.argmax(prev_transformed, 1))
            output = tf.nn.embedding_lookup(embedding_mat, prev_symbol)
            return output
        decoder = tf.contrib.legacy_seq2seq.rnn_decoder
        outputs, last_state = decoder(decoder_inputs=rnn_inputs_trimmed, initial_state=self.initial_state, cell=self.lstm_cell, loop_function=inferred_loop if self.infer else None)
        output = tf.reshape(tf.concat(axis=1, values=outputs), [-1, self.rnn_size])
        self.logit_output = tf.matmul(output, W) + b
        self.model_output = tf.nn.softmax(self.logit_output)
        loss_fun = tf.contrib.legacy_seq2seq.sequence_loss_by_example
        loss = loss_fun([self.logit_output], [tf.reshape(self.y_output, [-1])], [tf.ones([self.batch_size * self.training_seq_len])])
        self.cost = tf.reduce_sum(loss) / (self.batch_size * self.training_seq_len)
        self.final_state = last_state
        gradients, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tf.trainable_variables()), 4.5)
        optimizer = tf.train.AdamOptimizer(self.learning_rate)
        self.train_op = optimizer.apply_gradients(zip(gradients, tf.trainable_variables()))
    def sample(self, sess, words=ix2vocab, vocab=vocab2ix, num=10, prime_text='thou art'):
        state = sess.run(self.lstm_cell.zero_state(1, tf.float32))
        word_list = prime_text.split()
        for word in word_list[:-1]:
            x = np.zeros((1,1))
            x[0,0] = vocab[word]
            feed_dict = {self.x_data: x, self.initial_state: state}
            [state] = sess.run([self.final_state], feed_dict=feed_dict)
        out_sentence = prime_text
        word = word_list[-1]
        for n in range(num):
            x = np.zeros((1,1))
            x[0,0] = vocab[word]
            feed_dict = {self.x_data: x, self.initial_state: state}
            [model_output, state] = sess.run([self.model_output, self.final_state], feed_dict=feed_dict)
            sample = np.argmax(model_output[0])
            if sample == 0:
                break
            word = words[sample]
            out_sentence = out_sentence + ' ' + word
        return out_sentence
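
Before wiring the graph together, it can help to trace the tensor shapes the model pushes around. Below is a minimal NumPy-only sketch with toy sizes; it skips the LSTM cell itself and only mirrors the embedding lookup, the split/squeeze, and the final reshape-and-project step.

# Shape walkthrough with hypothetical toy sizes (toy_* names avoid clobbering the real hyperparameters)
import numpy as np
toy_bs, toy_sl, toy_rs, toy_vs = 2, 3, 4, 5
toy_embedding = np.random.randn(toy_vs, toy_rs)
toy_x = np.array([[1, 2, 0], [3, 4, 1]])                     # (toy_bs, toy_sl) word indices
toy_embedded = toy_embedding[toy_x]                          # (2, 3, 4), like tf.nn.embedding_lookup
toy_inputs = [toy_embedded[:, t, :] for t in range(toy_sl)]  # toy_sl arrays of (toy_bs, toy_rs), like tf.split + tf.squeeze
toy_flat = np.concatenate(toy_inputs, axis=1).reshape(-1, toy_rs)  # like the reshape of the decoder outputs
toy_W = np.random.randn(toy_rs, toy_vs)
toy_b = np.random.randn(toy_vs)
toy_logits = toy_flat @ toy_W + toy_b
print(toy_logits.shape)                                      # (6, 5) = (toy_bs*toy_sl, toy_vs)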

7 Declare the LSTM Model and a Test (Sampling) Model

# Training model
lstm_model = LSTM_Model(embedding_size, rnn_size, batch_size, learning_rate, training_seq_len, vocab_size)

# Test (sampling) model
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
    test_lstm_model = LSTM_Model(embedding_size, rnn_size, batch_size, learning_rate, training_seq_len, vocab_size, infer=True)
WARNING:tensorflow:From <ipython-input-6-0d3f7347e00d>:23: BasicLSTMCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This class is deprecated, please use tf.nn.rnn_cell.LSTMCell, which supports all the feature this cell currently has. Please replace the existing code with tf.nn.rnn_cell.LSTMCell(name='basic_lstm_cell').
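
The warning is harmless here. If you want to silence it, the cell construction in LSTM_Model.__init__ can be swapped for the class the warning itself suggests (an untested, assumed-equivalent substitution under the TF 1.x API):

# One possible replacement for the deprecated cell, per the warning's own suggestion
lstm_cell = tf.nn.rnn_cell.LSTMCell(rnn_size, name='basic_lstm_cell')  # instead of tf.nn.rnn_cell.BasicLSTMCell(rnn_size)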

8 Saver

saver = tf.train.Saver()

9 Split the Input Text into Batches

num_batches = int(len(s_text_ix)/(batch_size*training_seq_len)) + 1
batches = np.array_split(s_text_ix, num_batches)
batches = [np.resize(x, [batch_size, training_seq_len]) for x in batches]
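
The combination of np.array_split and np.resize can be illustrated with a toy index stream (hypothetical sizes): array_split produces roughly equal chunks, and np.resize forces each chunk into a fixed (batch_size, training_seq_len) shape, wrapping around if a chunk comes up short.

# Toy illustration with hypothetical sizes (batch_size=2, training_seq_len=4)
toy_ix = np.arange(23)                                      # pretend corpus of 23 word indices
toy_batches = np.array_split(toy_ix, 3)                     # chunks of length 8, 8, 7
toy_batches = [np.resize(x, [2, 4]) for x in toy_batches]   # every chunk becomes shape (2, 4)
print(toy_batches[-1])                                      # the short last chunk repeats its first index to fill 8 slots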

10 Initialize Variables

init = tf.global_variables_initializer()
sess.run(init)

11 Training

train_loss = []
iteration_count = 1
for epoch in range(epochs):
    random.shuffle(batches)  # shuffle the batch order each epoch
    targets = [np.roll(x, -1, axis=1) for x in batches]  # np.roll(x, shift, axis) rolls x by shift positions along axis; the targets are the inputs shifted left by one word
    print('Starting Epoch #{} of {}.'.format(epoch+1, epochs))
    # Reset initial LSTM state every epoch
    state = sess.run(lstm_model.initial_state)
    for ix, batch in enumerate(batches):
        training_dict = {lstm_model.x_data: batch, lstm_model.y_output: targets[ix]}
        c, h = lstm_model.initial_state
        training_dict[c] = state.c
        training_dict[h] = state.h
        
        temp_loss, state, _ = sess.run([lstm_model.cost, lstm_model.final_state, lstm_model.train_op], feed_dict=training_dict)
        train_loss.append(temp_loss)
        
        # Print status every 10 gens
        if iteration_count % 10 == 0:
            summary_nums = (iteration_count, epoch+1, ix+1, num_batches+1, temp_loss)
            print('Iteration: {}, Epoch: {}, Batch: {} out of {}, Loss: {:.2f}'.format(*summary_nums))
        
        # Save the model and the vocab
        if iteration_count % save_every == 0:
            # Save model
            model_file_name = os.path.join(full_model_dir, 'model')
            saver.save(sess, model_file_name, global_step=iteration_count)
            print('Model Saved To: {}'.format(model_file_name))
            # Save vocabulary
            dictionary_file = os.path.join(full_model_dir, 'vocab.pkl')
            with open(dictionary_file, 'wb') as dict_file_conn:
                pickle.dump([vocab2ix, ix2vocab], dict_file_conn)
        
        if iteration_count % eval_every == 0:
            for sample in prime_texts:
                print(test_lstm_model.sample(sess, ix2vocab, vocab2ix, num=10, prime_text=sample))
                
        iteration_count += 1   
Starting Epoch #1 of 1.
Iteration: 10, Epoch: 1, Batch: 10 out of 182, Loss: 10.30
Iteration: 20, Epoch: 1, Batch: 20 out of 182, Loss: 9.38
Iteration: 30, Epoch: 1, Batch: 30 out of 182, Loss: 8.99
Iteration: 40, Epoch: 1, Batch: 40 out of 182, Loss: 8.62
Iteration: 50, Epoch: 1, Batch: 50 out of 182, Loss: 8.41
thou art more than wide to
to be or not to the
wherefore art thou bassanio's bassanio's master a
Iteration: 60, Epoch: 1, Batch: 60 out of 182, Loss: 7.98
Iteration: 70, Epoch: 1, Batch: 70 out of 182, Loss: 7.98
Iteration: 80, Epoch: 1, Batch: 80 out of 182, Loss: 7.70
Iteration: 90, Epoch: 1, Batch: 90 out of 182, Loss: 7.63
Iteration: 100, Epoch: 1, Batch: 100 out of 182, Loss: 7.01
thou art more than wide to
to be or not to
wherefore art thou canst bassanio's a
Iteration: 110, Epoch: 1, Batch: 110 out of 182, Loss: 7.09
Iteration: 120, Epoch: 1, Batch: 120 out of 182, Loss: 7.10
Iteration: 130, Epoch: 1, Batch: 130 out of 182, Loss: 7.24
Iteration: 140, Epoch: 1, Batch: 140 out of 182, Loss: 6.74
Iteration: 150, Epoch: 1, Batch: 150 out of 182, Loss: 6.76
thou art more than than to
to be or not to the
wherefore art thou canst fall'n sycorax clown canst clown piteous fran base
Iteration: 160, Epoch: 1, Batch: 160 out of 182, Loss: 6.65
Iteration: 170, Epoch: 1, Batch: 170 out of 182, Loss: 6.60
Iteration: 180, Epoch: 1, Batch: 180 out of 182, Loss: 6.82
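
Because the loop above periodically saves checkpoints and the vocabulary, both can be restored in a later session. A minimal sketch, assuming the same full_model_dir and the same graph as above:

# Restore the latest checkpoint and the pickled vocabulary (sketch, same paths as used during training)
with open(os.path.join(full_model_dir, 'vocab.pkl'), 'rb') as dict_file_conn:
    vocab2ix, ix2vocab = pickle.load(dict_file_conn)
saver.restore(sess, tf.train.latest_checkpoint(full_model_dir))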

12 Plot the Training Loss

plt.plot(train_loss, 'k-')
plt.title('Sequence to Sequence Loss')
plt.xlabel('Generation')
plt.ylabel('Loss')
plt.show()     

(Figure: training loss per generation, titled "Sequence to Sequence Loss")

# Small examples
a = [[1, 2, 3], [2, 3, 2]]
b = tf.split(a, 3, 1)                # split along axis 1 into 3 tensors of shape (2, 1)
sess = tf.Session()
print(sess.run(b))
print(sess.run(tf.squeeze(b, [2])))  # drop the size-1 axis
[array([[1],
       [2]]), array([[2],
       [3]]), array([[3],
       [2]])]
[[1 2]
 [2 3]
 [3 2]]
c = np.array_split(a, 2)
c
[array([[1, 2, 3]]), array([[2, 3, 2]])]
x = np.arange(12).reshape(3, 4)  # example array
np.roll(x, -1, axis=1) 
array([[ 1,  2,  3,  0],
       [ 5,  6,  7,  4],
       [ 9, 10, 11,  8]])


Reprinted from: blog.csdn.net/qq_40006058/article/details/91795147