86.编码器-解码器架构以及代码实现

1. 重新考察CNN

在这里插入图片描述

2. 重新考察RNN

在这里插入图片描述

3. 编码器-解码器架构

在这里插入图片描述

4. 总结

  • 使用编码器-解码器架构的模型,编码器负责表示输入,解码器负责输出

5. 代码实现

5.1 编码器

在编码器接口中,我们只指定长度可变的序列作为编码器的输入X。 任何继承这个Encoder基类的模型将完成代码实现。

from torch import nn

class Encoder(nn.Module):
    """编码器-解码器架构的基本编码器接口"""
    def __init__(self, **kwargs):
        super(Encoder, self).__init__(**kwargs)

    def forward(self, X, *args):
        raise NotImplementedError

5.2 解码器

在下面的解码器接口中,我们新增一个init_state函数, 用于将编码器的输出(enc_outputs)转换为编码后的状态。 注意,此步骤可能需要额外的输入,例如:输入序列的有效长度。 为了逐个地生成长度可变的词元序列, 解码器在每个时间步都会将输入 (例如:在前一时间步生成的词元)和编码后的状态 映射成当前时间步的输出词元。

class Decoder(nn.Module):
    """编码器-解码器架构的基本解码器接口"""
    def __init__(self, **kwargs):
        super(Decoder, self).__init__(**kwargs)

    # enc_outputs是encoder所有的输出
    def init_state(self, enc_outputs, *args):
        raise NotImplementedError

    # state一开始是从encoder拿过来,之后不断更新
    # X是额外的输入
    def forward(self, X, state):
        raise NotImplementedError

5.3 合并编码器和解码器

总而言之,“编码器-解码器”架构包含了一个编码器和一个解码器, 并且还拥有可选的额外的参数。 在前向传播中,编码器的输出用于生成编码状态, 这个状态又被解码器作为其输入的一部分。

class EncoderDecoder(nn.Module):
    """编码器-解码器架构的基类"""
    def __init__(self, encoder, decoder, **kwargs):
        super(EncoderDecoder, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
    	# enc_X 是encoder的输入
        enc_outputs = self.encoder(enc_X, *args)
        # 把encoder的输出拿到解码器的init_state中,变成了解码器的初始状态
        dec_state = self.decoder.init_state(enc_outputs, *args)
        # 再把中间状态dec_state和decoder的输入dec_X传入解码器
        return self.decoder(dec_X, dec_state)

“编码器-解码器”体系架构中的术语状态 会启发人们使用具有状态的神经网络来实现该架构。 在下一节中,我们将学习如何应用循环神经网络, 来设计基于“编码器-解码器”架构的序列转换模型。

猜你喜欢

转载自blog.csdn.net/weixin_47505105/article/details/128729706