Seq2Seq Model Application Example



Seq2Seq is an Encoder-Decoder model: the input is a sequence and the output is also a sequence, and the two do not have to be of equal length. It suits tasks such as machine translation, human-machine dialogue, and chatbots.

This example, adapted from a fellow developer's GitHub code, implements a simple Seq2Seq dialogue model. The training corpus is as follows:

Encoder sentence sequences:

    [['Hi What is your name?', 'Nice to meet you!'],
     ['Which programming language do you use?', 'See you later.'],
     ['Where do you live?', 'What is your major?'],
     ['What do you want to drink?', 'What is your favorite beer?']]

Decoder sentence sequences:

    [['Hi this is Jaemin.', 'Nice to meet you too!'],
     ['I like Python.', 'Bye Bye.'],
     ['I live in Seoul, South Korea.', 'I study industrial engineering.'],
     ['Beer please!', 'Leffe brown!']]
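
In the linked notebook these sentence pairs are presumably packed into two nested lists, which train() later unpacks via input_batches, target_batches = data. A minimal sketch of that wiring (the variable names are taken from the train() code below; the grouping into four mini-batches of two pairs each mirrors the lists above):

# Four mini-batches of two sentence pairs each; names match the unpacking in train().
input_batches = [
    ['Hi What is your name?', 'Nice to meet you!'],
    ['Which programming language do you use?', 'See you later.'],
    ['Where do you live?', 'What is your major?'],
    ['What do you want to drink?', 'What is your favorite beer?']]

target_batches = [
    ['Hi this is Jaemin.', 'Nice to meet you too!'],
    ['I like Python.', 'Bye Bye.'],
    ['I live in Seoul, South Korea.', 'I study industrial engineering.'],
    ['Beer please!', 'Leffe brown!']]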

Seq2Seq model code:

I. Create a configuration class that holds the hyperparameters. The RNN cell is BasicLSTMCell.
    hidden_size = 30
    enc_emb_size = 30
    dec_emb_size = 30
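
A minimal sketch of such a configuration class, assuming the attribute names consumed by Seq2SeqModel.__init__ below (hidden_size, enc_emb_size, dec_emb_size, cell, optimizer, n_epoch, learning_rate, ckpt_dir); the values not listed above (n_epoch, learning_rate, ckpt_dir) are illustrative:

import tensorflow as tf

class Config(object):
    hidden_size = 30                         # LSTM hidden units
    enc_emb_size = 30                        # encoder word-embedding size
    dec_emb_size = 30                        # decoder word-embedding size
    cell = tf.nn.rnn_cell.BasicLSTMCell      # RNN cell class, instantiated as self.cell(hidden_size)
    optimizer = tf.train.RMSPropOptimizer    # optimizer class used in add_training_op
    n_epoch = 2000                           # number of training epochs (illustrative)
    learning_rate = 0.001                    # learning rate (illustrative)
    ckpt_dir = 'ckpt_dir/'                   # checkpoint / TensorBoard directory (illustrative)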
    
II. Create the Seq2SeqModel class:

1. __init__: initialization.

2. add_placeholders: data placeholders.
   self.enc_inputs: shape=[None, enc_sentence_length] (encoder batch size, maximum encoder sentence length)
   self.enc_sequence_length: shape=[None,] (encoder batch size,)
   In training mode only:
   self.dec_inputs: shape=[None, dec_sentence_length+1] (decoder batch size, maximum decoder sentence length + 1). The +1 is needed because, when decoder sentences are converted from words to indices by sent2idx, the start symbol '_GO' (index [0]) is prepended to every sentence; a sketch of sent2idx follows this list.
   self.dec_sequence_length: shape=[None,] (decoder batch size,)
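
The sent2idx helper referred to above is defined in the notebook's preprocessing code, which is not reproduced in this article. The sketch below shows its likely behaviour, assuming word-level tokenization, an encoder vocabulary that reserves index 0 for '_PAD', and a decoder vocabulary that reserves index 0 for '_GO' and index 1 for '_PAD' (as described under add_encoder and add_decoder below); the exact implementation in the notebook may differ:

def sent2idx(sent, vocab=None, max_sentence_length=None, is_target=False):
    # Sketch only. Encoder sentences use the notebook globals enc_vocab /
    # enc_sentence_length as defaults; decoder sentences pass dec_vocab /
    # dec_sentence_length and is_target=True explicitly (see the training loop).
    if vocab is None:
        vocab, max_sentence_length = enc_vocab, enc_sentence_length
    tokens = sent.split()                    # assumed word-level tokenization
    token_ids = [vocab[t] for t in tokens]
    if is_target:
        # Prepend the '_GO' start symbol (one extra step), then pad with '_PAD'
        # up to max_sentence_length + 1, matching the dec_inputs placeholder shape.
        token_ids = [vocab['_GO']] + token_ids
        padded = token_ids + [vocab['_PAD']] * (max_sentence_length + 1 - len(token_ids))
    else:
        # Encoder sentences are only padded with '_PAD' (index 0).
        padded = token_ids + [vocab['_PAD']] * (max_sentence_length - len(token_ids))
    return padded, len(tokens)               # padded ids and the true (unpadded) length

The inverse helper idx2sent(pred, reverse_vocab=...), used when printing predictions, presumably maps indices back to words and drops the special symbols.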

3. add_encoder: the encoder.
  self.enc_Wemb: [enc_vocab_size+1, self.enc_emb_size]. The +1 accounts for the '_PAD' symbol (index [0]) used to pad encoder sentences.
  enc_emb_inputs: [batch_size, enc_sent_len, enc_emb_size]
  enc_cell = self.cell(self.hidden_size)
  enc_outputs: [batch_size, enc_sent_len, hidden_size]
  enc_last_state: the encoder's final LSTM state, each tensor of shape [batch_size, hidden_size] (here hidden_size happens to equal the embedding size, 30)
  

4. add_decoder: the decoder.
 self.dec_Wemb: [dec_vocab_size+2, self.dec_emb_size]. The +2 is because every decoder sentence is prefixed with the start symbol '_GO' (index [0]) and padded with '_PAD' (index [1]).
 max_dec_len: tf.reduce_max(self.dec_sequence_length+1), i.e. the longest decoder sentence in the batch plus 1 for the '_GO' start symbol.
 training_helper = tf.contrib.seq2seq.TrainingHelper(
                    inputs=dec_emb_inputs,
                    sequence_length=self.dec_sequence_length+1,
                    time_major=False,
                    name='training_helper')
TrainingHelper is used during training (teacher forcing): because the target sequences are known, the ground-truth word at each step is fed to the decoder as the next input instead of the model's own prediction, which makes training more accurate and stable.
                             
        
5. Inference is slightly different: instead of TrainingHelper, tf.contrib.seq2seq.GreedyEmbeddingHelper is used. Each predicted word is embedded and fed back into the RNN as the next input, and the decoder keeps predicting one word at a time until the end token is emitted.

6. Model training and utilities:
add_training_op: uses the tf.train.RMSPropOptimizer optimizer.
save: saves the model.
restore: loads the model.
summary: writes the graph for display in TensorBoard.
build: wires add_placeholders, add_encoder and add_decoder together.
train: trains the model; tqdm is used to show a progress bar. A minimal driver that ties these calls together is sketched below.
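
Putting these methods together (the full class is listed further below), a minimal training driver might look like the following sketch. The driver itself is not part of the class listing, so the checkpoint file name and the way the corpus is packed into data are assumptions consistent with the train() signature:

tf.reset_default_graph()

config = Config()                            # hyperparameter object (sketched earlier)
model = Seq2SeqModel(config, mode='training')
model.build()                                # add_placeholders + add_encoder + add_decoder
model.summary()                              # write the graph for TensorBoard

with tf.Session() as sess:
    loss_history = model.train(
        sess,
        data=(input_batches, target_batches),
        from_scratch=True,                   # train from scratch, skip checkpoint restore
        save_path=config.ckpt_dir + 'seq2seq.ckpt')   # assumed checkpoint file name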

Problem: the Seq2Seq model looks very good during training but performs poorly when predicting on the test set. The difference between training and inference is:

  • During training, the decoder's inputs are the actual target tokens.
  • During inference, the decoder's inputs are the tokens it predicted itself; if the first few words are predicted incorrectly, the errors mislead the prediction of all subsequent words.

Seq2Seq schematic (figure omitted).

Full code for this Seq2Seq example: https://github.com/duanzhihua/tf_tutorial_plus/blob/master/RNN_seq2seq/contrib_seq2seq/01_TrainingHelper.ipynb

import os

import tensorflow as tf
from tensorflow.python.layers.core import Dense  # output projection layer
from tqdm import tqdm                             # progress bar used in train()

# enc_sentence_length, dec_sentence_length, enc_vocab_size, dec_vocab_size,
# enc_vocab, dec_vocab, dec_reverse_vocab, sent2idx and idx2sent come from the
# notebook's vocabulary/preprocessing code, which is not shown here.


class Seq2SeqModel(object):
    def __init__(self, config, mode='training'):
        assert mode in ['training', 'evaluation', 'inference']
        self.mode = mode

        # Model
        self.hidden_size = config.hidden_size
        self.enc_emb_size = config.enc_emb_size
        self.dec_emb_size = config.dec_emb_size
        self.cell = config.cell
        
        # Training
        self.optimizer = config.optimizer
        self.n_epoch = config.n_epoch
        self.learning_rate = config.learning_rate
        
        # Checkpoint Path
        self.ckpt_dir = config.ckpt_dir
        
    def add_placeholders(self):
        self.enc_inputs = tf.placeholder(
            tf.int32,
            shape=[None, enc_sentence_length],
            name='input_sentences')

        self.enc_sequence_length = tf.placeholder(
            tf.int32,
            shape=[None,],
            name='input_sequence_length')
        
        if self.mode == 'training':
            self.dec_inputs = tf.placeholder(
                tf.int32,
                shape=[None, dec_sentence_length+1],
                name='target_sentences')

            self.dec_sequence_length = tf.placeholder(
                tf.int32,
                shape=[None,],
                name='target_sequence_length')
            
    def add_encoder(self):
        with tf.variable_scope('Encoder') as scope:
            with tf.device('/cpu:0'):
                self.enc_Wemb = tf.get_variable('embedding',
                    initializer=tf.random_uniform([enc_vocab_size+1, self.enc_emb_size]),
                    dtype=tf.float32)

            # [Batch_size x enc_sent_len x embedding_size]
            enc_emb_inputs = tf.nn.embedding_lookup(
                self.enc_Wemb, self.enc_inputs, name='emb_inputs')
            enc_cell = self.cell(self.hidden_size)

            # enc_outputs: [batch_size x enc_sent_len x hidden_size]
            # enc_last_state: LSTMStateTuple, each part [batch_size x hidden_size]
            enc_outputs, self.enc_last_state = tf.nn.dynamic_rnn(
                cell=enc_cell,
                inputs=enc_emb_inputs,
                sequence_length=self.enc_sequence_length,
                time_major=False,
                dtype=tf.float32)
            
    def add_decoder(self):
        with tf.variable_scope('Decoder') as scope:
            with tf.device('/cpu:0'):
                self.dec_Wemb = tf.get_variable('embedding',
                    initializer=tf.random_uniform([dec_vocab_size+2, self.dec_emb_size]),
                    dtype=tf.float32)

            dec_cell = self.cell(self.hidden_size)

            # output projection (replacing `OutputProjectionWrapper`)
            output_layer = Dense(dec_vocab_size+2, name='output_projection')
            
            if self.mode == 'training':

                # maximum unrollings in current batch = max(dec_sent_len) + 1 (GO symbol)
                max_dec_len = tf.reduce_max(self.dec_sequence_length+1, name='max_dec_len')

                dec_emb_inputs = tf.nn.embedding_lookup(
                    self.dec_Wemb, self.dec_inputs, name='emb_inputs')
        
                training_helper = tf.contrib.seq2seq.TrainingHelper(
                    inputs=dec_emb_inputs,
                    sequence_length=self.dec_sequence_length+1,
                    time_major=False,
                    name='training_helper')

                training_decoder = tf.contrib.seq2seq.BasicDecoder(
                    cell=dec_cell,
                    helper=training_helper,
                    initial_state=self.enc_last_state,
                    output_layer=output_layer) 

                train_dec_outputs, train_dec_last_state = tf.contrib.seq2seq.dynamic_decode(
                    training_decoder,
                    output_time_major=False,
                    impute_finished=True,
                    maximum_iterations=max_dec_len)
                
                # train_dec_outputs: BasicDecoderOutput namedtuple (rnn_output, sample_id)
                # train_dec_outputs.rnn_output: [batch_size x max(dec_sequence_len) x dec_vocab_size+2], tf.float32
                # train_dec_outputs.sample_id: [batch_size x max(dec_sequence_len)], tf.int32
                
                # logits: [batch_size x max_dec_len x dec_vocab_size+2]
                logits = tf.identity(train_dec_outputs.rnn_output, name='logits')
                
                # targets: [batch_size x max_dec_len], tf.int32 (token ids, not one-hot)
                targets = tf.slice(self.dec_inputs, [0, 0], [-1, max_dec_len], 'targets')
                
                # masks: [batch_size x max_dec_len]
                # => ignore outputs after `dec_sequence_length+1` when calculating loss
                masks = tf.sequence_mask(self.dec_sequence_length+1, max_dec_len, dtype=tf.float32, name='masks')
                
                # Control loss dimensions with `average_across_timesteps` and `average_across_batch`
                # internal: `tf.nn.sparse_softmax_cross_entropy_with_logits`
                self.batch_loss = tf.contrib.seq2seq.sequence_loss(
                    logits=logits,
                    targets=targets,
                    weights=masks,
                    name='batch_loss')
                
                # prediction sample for validation
                self.valid_predictions = tf.identity(train_dec_outputs.sample_id, name='valid_preds')

                # List of training variables
                # self.training_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
            
            elif self.mode == 'inference':
            
                batch_size = tf.shape(self.enc_inputs)[0:1]
                start_tokens = tf.zeros(batch_size, dtype=tf.int32)
            
                inference_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
                    embedding=self.dec_Wemb,
                    start_tokens=start_tokens,
                    end_token=1)
                
                inference_decoder = tf.contrib.seq2seq.BasicDecoder(
                    cell=dec_cell,
                    helper=inference_helper,
                    initial_state=self.enc_last_state,
                    output_layer=output_layer)
                
                infer_dec_outputs, infer_dec_last_state = tf.contrib.seq2seq.dynamic_decode(
                    inference_decoder,
                    output_time_major=False,
                    impute_finished=True,
                    maximum_iterations=dec_sentence_length)
                
                # [batch_size x dec_sentence_length], tf.int32
                self.predictions = tf.identity(infer_dec_outputs.sample_id, name='predictions')
                # equivalent to tf.argmax(infer_dec_outputs.rnn_output, axis=2, name='predictions')
                
                # List of training variables
                # self.training_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        
    def add_training_op(self):
        self.training_op = self.optimizer(self.learning_rate, name='training_op').minimize(self.batch_loss)
        
    def save(self, sess, var_list=None, save_path=None):
        print(f'Saving model at {save_path}')
        if hasattr(self, 'training_variables'):
            var_list = self.training_variables
        saver = tf.train.Saver(var_list)
        saver.save(sess, save_path, write_meta_graph=False)
        
    def restore(self, sess, var_list=None, ckpt_path=None):
        if hasattr(self, 'training_variables'):
            var_list = self.training_variables
        self.restorer = tf.train.Saver(var_list)
        self.restorer.restore(sess, ckpt_path)
        print('Restore Finished!')
        
    def summary(self):
        summary_writer = tf.summary.FileWriter(
            logdir=self.ckpt_dir,
            graph=tf.get_default_graph())
        
    def build(self):
        self.add_placeholders()
        self.add_encoder()
        self.add_decoder()

    def train(self, sess, data, from_scratch=False,
              load_ckpt=None, save_path=None):
        
        # Restore Checkpoint
        if from_scratch is False and os.path.isfile(load_ckpt):
            self.restore(sess, ckpt_path=load_ckpt)
    
        # Add Optimizer to current graph
        self.add_training_op()
        
        sess.run(tf.global_variables_initializer())
        
        input_batches, target_batches = data
        loss_history = []
        
        for epoch in tqdm(range(self.n_epoch)):

            all_preds = []
            epoch_loss = 0
            for input_batch, target_batch in zip(input_batches, target_batches):
                input_batch_tokens = []
                target_batch_tokens = []
                enc_sentence_lengths = []
                dec_sentence_lengths = []

                for input_sent in input_batch:
                    tokens, sent_len = sent2idx(input_sent)
                    input_batch_tokens.append(tokens)
                    enc_sentence_lengths.append(sent_len)

                for target_sent in target_batch:
                    tokens, sent_len = sent2idx(target_sent,
                                 vocab=dec_vocab,
                                 max_sentence_length=dec_sentence_length,
                                 is_target=True)
                    target_batch_tokens.append(tokens)
                    dec_sentence_lengths.append(sent_len)
       
                # Evaluate 3 ops in the graph
                # => valid_predictions, loss, training_op (optimizer)
                batch_preds, batch_loss, _ = sess.run(
                    [self.valid_predictions, self.batch_loss, self.training_op],
                    feed_dict={
                        self.enc_inputs: input_batch_tokens,
                        self.enc_sequence_length: enc_sentence_lengths,
                        self.dec_inputs: target_batch_tokens,
                        self.dec_sequence_length: dec_sentence_lengths,
                    })
                # loss_history.append(batch_loss)
                epoch_loss += batch_loss
                all_preds.append(batch_preds)
                
            loss_history.append(epoch_loss)

            # Logging every 400 epochs
            if epoch % 400 == 0:
                print('Epoch', epoch)
                for input_batch, target_batch, batch_preds in zip(input_batches, target_batches, all_preds):
                    for input_sent, target_sent, pred in zip(input_batch, target_batch, batch_preds):
                        print(f'\tInput: {input_sent}')
                        print('\tPrediction:', idx2sent(pred, reverse_vocab=dec_reverse_vocab))
                        print(f'\tTarget: {target_sent}')
                print(f'\tepoch loss: {epoch_loss:.2f}\n')
                
        if save_path:
            self.save(sess, save_path=save_path)

        return loss_history
    
    def inference(self, sess, data, load_ckpt):
        
        self.restore(sess, ckpt_path=load_ckpt)
        
        input_batch, target_batch = data

        batch_preds = []
        batch_tokens = []
        batch_sent_lens = []
        
        for input_sent in input_batch:
            tokens, sent_len = sent2idx(input_sent)
            batch_tokens.append(tokens)
            batch_sent_lens.append(sent_len)
            
        batch_preds = sess.run(
            self.predictions,
            feed_dict={
                self.enc_inputs: batch_tokens,
                self.enc_sequence_length: batch_sent_lens,
            })

        for input_sent, target_sent, pred in zip(input_batch, target_batch, batch_preds):
            print('Input:', input_sent)
            print('Prediction:', idx2sent(pred, reverse_vocab=dec_reverse_vocab))
            print('Target:', target_sent, '\n')
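
Inference runs against a freshly built graph in 'inference' mode and restores the weights saved during training. A minimal sketch (the test sentences and the checkpoint path are illustrative; the second element of data is only printed as the reference "Target"):

tf.reset_default_graph()

infer_model = Seq2SeqModel(config, mode='inference')
infer_model.build()                          # decoding now uses GreedyEmbeddingHelper

test_data = (['Hi What is your name?', 'Which programming language do you use?'],
             ['Hi this is Jaemin.', 'I like Python.'])

with tf.Session() as sess:
    infer_model.inference(sess, test_data,
                          load_ckpt=config.ckpt_dir + 'seq2seq.ckpt')   # assumed path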
