Transformer Decoder Explained

How the Transformer Decoder Works

The decoder is a stack of N identical layers. Each layer applies masked self-attention over the target sequence, attention over the encoder output (the memory), and a position-wise feed-forward network, with every sub-layer wrapped in a residual connection and layer normalization.
Decoder Layer

import copy
import torch
import torch.nn as nn

def clones(module, N):
    # produce N independent deep copies of a module
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

class DecoderLayer(nn.Module):
    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super().__init__()
        self.size = size
        self.self_attn = self_attn      # masked self-attention over the target
        self.src_attn = src_attn        # attention over the encoder output
        self.feed_forward = feed_forward
        # one residual/norm wrapper per sub-layer
        self.sublayer = clones(SublayerConnection(size, dropout), 3)

    def forward(self, x, memory, source_mask, target_mask):
        m = memory
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, target_mask))
        x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, source_mask))
        return self.sublayer[2](x, self.feed_forward)
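The layer above assumes a SublayerConnection helper (a residual connection around a layer-normalized sub-layer) that this post never defines. A minimal pre-norm sketch along the lines of the Annotated Transformer:

class SublayerConnection(nn.Module):
    # residual connection: x + dropout(sublayer(norm(x)))
    def __init__(self, size, dropout):
        super().__init__()
        self.norm = nn.LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))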


# illustrative call: size, self_attn, src_attn, ff and the inputs are built below
dl = DecoderLayer(size, self_attn, src_attn, ff, dropout)
dl_result = dl(x, memory, source_mask, target_mask)  # same shape as x
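The two masks play different roles: source_mask hides padding positions in the encoder output, while target_mask additionally blocks attention to later target positions, so the decoder stays autoregressive during training. A minimal look-ahead mask helper (the name subsequent_mask is borrowed from the Annotated Transformer; this series may build it differently):

def subsequent_mask(size):
    # True on and below the diagonal: position i may attend to positions <= i
    return torch.triu(torch.ones(1, size, size), diagonal=1) == 0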



Decoder

The full decoder simply stacks N copies of the layer above and normalizes the final output:

class Decoder(nn.Module):
    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        # final normalization after the stack (the sub-layers are pre-norm)
        self.norm = nn.LayerNorm(layer.size)

    def forward(self, x, memory, source_mask, target_mask):
        for layer in self.layers:
            x = layer(x, memory, source_mask, target_mask)
        return self.norm(x)
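The construction below reuses MultiHeadedAttention and PositionwiseFeedForward, which this post does not define (they come from the encoder installments of the series). Minimal sketches in the style of the Annotated Transformer follow so that the example is self-contained; the series' actual implementations may differ:

import math
import torch.nn.functional as F

class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        super().__init__()
        assert d_model % h == 0
        self.d_k = d_model // h   # dimensionality of each head
        self.h = h
        # query, key, value projections plus the output projection
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        if mask is not None:
            mask = mask.unsqueeze(1)  # broadcast the mask over all heads
        nbatches = query.size(0)
        # project, then split into h heads: (batch, h, seq_len, d_k)
        query, key, value = [
            lin(t).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
            for lin, t in zip(self.linears, (query, key, value))
        ]
        # scaled dot-product attention
        scores = query @ key.transpose(-2, -1) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn = self.dropout(scores.softmax(dim=-1))
        x = attn @ value
        # merge the heads back and apply the output projection
        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](x)

class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # FFN(x) = W2(dropout(relu(W1(x))))
        return self.w_2(self.dropout(F.relu(self.w_1(x))))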

# base-Transformer hyperparameters (Vaswani et al., 2017)
head, d_model, d_ff, dropout, N = 8, 512, 2048, 0.1, 6
size = d_model

c = copy.deepcopy
attn = MultiHeadedAttention(head, d_model)
feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)

# every layer receives its own copies of the attention and feed-forward modules
layer = DecoderLayer(size, c(attn), c(attn), c(feed_forward), dropout)
de = Decoder(layer, N)

# x: target embeddings, memory: encoder output, masks as discussed above
output_de = de(x, memory, source_mask, target_mask)
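Concretely, the whole pipeline can be checked end to end with dummy tensors (batch size 2, source length 9 and target length 7 are arbitrary choices for illustration):

x = torch.randn(2, 7, d_model)                       # target embeddings
memory = torch.randn(2, 9, d_model)                  # encoder stack output
source_mask = torch.ones(2, 1, 9, dtype=torch.bool)  # no source padding here
target_mask = subsequent_mask(7)                     # look-ahead mask
output_de = de(x, memory, source_mask, target_mask)
print(output_de.shape)  # torch.Size([2, 7, 512])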


Reposted from blog.csdn.net/qq_41128383/article/details/105729576