Implementing the decoder part of the Transformer model

Note: Some of this content comes from online tutorials; if there is any infringement, please contact me and I will delete it.

Tutorial link: 2.4.2 Decoder-part2_哔哩哔哩_bilibili

1. The role of the decoder layer

The decoder layer is the basic building block of the decoder. Given its input, each decoder layer extracts features oriented toward the target sequence; this is the decoding process.

Code:

import torch
import torch.nn as nn

class DecoderLayer(nn.Module):
    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        '''
        size: the word-embedding dimension, which also serves as the size of the decoder
        self_attn: a multi-head self-attention object, i.e. attention with Q = K = V
        src_attn: a multi-head attention object, here with Q != K = V
        feed_forward: the position-wise feed-forward layer
        dropout: the dropout rate
        '''
        super(DecoderLayer, self).__init__()

        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.dropout = dropout
        self.sublayer = clones(SublayerConnection(size, dropout), 3)

    def forward(self, x, memory, source_mask, target_mask):
        '''
        x: the output of the previous layer
        memory: the semantic memory variable produced by the encoder
        source_mask: mask tensor for the source data
        target_mask: mask tensor for the target data
        '''
        m = memory
        # Self-attention with target_mask, which hides future positions
        # so the decoder cannot peek at tokens it has not yet generated
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, target_mask))
        # Encoder-decoder attention with source_mask, which screens out
        # source positions that contribute nothing to the result (e.g. padding)
        x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, source_mask))
        # The final output: features extracted jointly from the encoder
        # output and the target data
        return self.sublayer[2](x, self.feed_forward)
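
The DecoderLayer above relies on the clones helper, the SublayerConnection class, and the LayerNorm class built in the earlier encoder sections of this tutorial. For readers landing on this part directly, here is a minimal sketch of those helpers in the same style (the exact definitions should match whatever was built in the earlier parts):

import copy

def clones(module, N):
    # Produce N identical copies of a module, registered in a ModuleList
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

class LayerNorm(nn.Module):
    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))   # learnable gain
        self.b_2 = nn.Parameter(torch.zeros(features))  # learnable bias
        self.eps = eps  # small constant to avoid division by zero

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2

class SublayerConnection(nn.Module):
    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        # Residual connection around a normalized sublayer:
        # x + dropout(sublayer(norm(x)))
        return x + self.dropout(sublayer(self.norm(x)))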
    
    

2. The role of the decoder

Based on the encoder's output and the result of the previous prediction, the decoder produces a feature representation of the next possible value.

Code (the decoder is simply a stack of N decoder layers followed by a normalization layer):

class Decoder(nn.Module):
    def __init__(self, layer, N):
        '''
        layer: a decoder layer instance
        N: the number of decoder layers to stack
        '''
        super(Decoder, self).__init__()

        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, memory, source_mask, target_mask):
        # Pass the input through each of the N decoder layers in turn
        for layer in self.layers:
            x = layer(x, memory, source_mask, target_mask)
        # Return the final representation of the decoding process
        return self.norm(x)
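
As a quick sanity check we can stack everything together. The snippet below is a minimal sketch: it assumes the MultiHeadedAttention and PositionwiseFeedForward classes from the earlier encoder sections of this tutorial, and the all-zero masks are placeholders used only to verify the output shape.

# Hyperparameter values matching the earlier parts of the tutorial
size = d_model = 512   # embedding dimension
head = 8               # number of attention heads
d_ff = 2048            # feed-forward hidden size
dropout = 0.2
N = 6                  # number of stacked decoder layers

# MultiHeadedAttention and PositionwiseFeedForward are assumed to be
# defined as in the earlier encoder sections
attn = MultiHeadedAttention(head, d_model, dropout)
ff = PositionwiseFeedForward(d_model, d_ff, dropout)

c = copy.deepcopy
layer = DecoderLayer(size, c(attn), c(attn), c(ff), dropout)
de = Decoder(layer, N)

x = torch.randn(2, 4, d_model)       # target embeddings: (batch, tgt_len, d_model)
memory = torch.randn(2, 4, d_model)  # encoder output
source_mask = target_mask = torch.zeros(2, 4, 4)  # placeholder masks

out = de(x, memory, source_mask, target_mask)
print(out.shape)  # expected: torch.Size([2, 4, 512])

Each layer receives its own deep copy of the attention and feed-forward modules, so the N stacked decoder layers do not share parameters.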
    
