Implementation of seq2seq (2)

Continued from the previous article:

This article implements method (3). Here the encoder's hidden state is combined with the decoder embedding by element-wise summation; of course, this can be changed to other combination methods. The code is given directly below:


    def build_model(self):
        """Build and compile a seq2seq (encoder-decoder) Keras model.

        The encoder's final output is repeated across every decoder timestep
        and summed element-wise with the decoder input embeddings; the
        encoder's final (h, c) states initialise the decoder LSTM.

        Returns:
            A compiled Keras ``Model`` mapping ``[encoder_input, decoder_input]``
            to per-timestep softmax distributions over the target vocabulary.
        """
        # Encoder: embed the source sequence and run an LSTM, keeping the
        # final hidden/cell states for the decoder's initial state.
        encoder_input = layers.Input(shape=(self.input_seq_len,))
        encoder_embeding = layers.Embedding(input_dim=len(self.en_word_id_dict),
                                            output_dim=self.encode_embeding_len,
                                            mask_zero=True  # treat id 0 as padding
                                            )(encoder_input)
        encoder_lstm, state_h, state_c = layers.LSTM(units=self.encode_embeding_len,
                                                     return_state=True)(encoder_embeding)

        encoder_state = [state_h, state_c]

        # Decoder input: embed the (teacher-forced) target sequence.
        decoder_input = layers.Input(shape=(self.output_seq_len,))
        decoder_embeding = layers.Embedding(input_dim=len(self.ch_word_id_dict),
                                            output_dim=self.decode_embeding_len,
                                            mask_zero=True
                                            )(decoder_input)
        # Combine the embedding with the encoder output: repeat the encoder's
        # final output over every decoder timestep, then sum element-wise.
        # layers.merge(..., mode="sum") was the Keras 1 API (removed in
        # Keras 2); layers.add is the current equivalent. Likewise
        # layers.core.RepeatVector is the old module path.
        # NOTE(review): the sum requires decode_embeding_len ==
        # encode_embeding_len — confirm both attributes hold the same value.
        encoder_lstm = layers.RepeatVector(self.output_seq_len)(encoder_lstm)
        decoder_embeding = layers.add([encoder_lstm, decoder_embeding])

        decoder_lstm, _, _ = layers.LSTM(units=self.encode_embeding_len,
                                         return_state=True,
                                         return_sequences=True)(decoder_embeding, initial_state=encoder_state)
        # Per-timestep softmax over the target vocabulary.
        decoder_out = layers.Dense(len(self.ch_word_id_dict), activation="softmax")(decoder_lstm)

        model = Model([encoder_input, decoder_input], decoder_out)
        model.compile(optimizer='rmsprop', loss='categorical_crossentropy')

        model.summary()
        return model

You may also like

Origin blog.csdn.net/cyinfi/article/details/88375671