Transformer decoder

1 Decoder layer

1.1 Decoder sublayer

class DecoderLayer(nn.Module):
    def __init__(self, embedding_dim, self_attn, src_attn, feed_forward, dropout):
        super(DecoderLayer, self).__init__()
        self.embedding_dim = embedding_dim
        self.self_attn = self_attn        # masked self-attention over the target
        self.src_attn = src_attn          # encoder-decoder (cross) attention
        self.feed_forward = feed_forward
        # three residual sublayer connections: self-attn, cross-attn, feed-forward
        self.layers = copy_model(SublayerConnection(embedding_dim), 3)

    def forward(self, x, memory, source_mask, target_mask):
        # 1. masked self-attention: queries, keys and values all come from x
        x = self.layers[0](x, lambda x: self.self_attn(x, x, x, target_mask))
        # 2. cross-attention: queries from the decoder, keys/values from encoder memory
        x = self.layers[1](x, lambda x: self.src_attn(x, memory, memory, source_mask))
        # 3. position-wise feed-forward network
        return self.layers[2](x, self.feed_forward)
# demo: one multi-head attention instance serves as both self- and cross-attention
self_attn = src_attn = MultiHeadAttention(head, embedding_dim, dropout)
d_ff = 64
ff = PositionwiseFeedForwaed(embedding_dim, d_ff, dropout)
memory = en_result   # encoder output from the encoder section
x = pe_result        # positionally encoded target embeddings
mask = torch.zeros(8, 4, 4)   # all-zero placeholder mask, for demonstration only
source_mask = target_mask = mask

dl = DecoderLayer(embedding_dim, self_attn, src_attn, ff, dropout)
dl_result = dl(x, memory, source_mask, target_mask)
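
An all-zero mask blocks every position, which is fine for a shape check but not for real decoding: the target mask must hide future positions so each step attends only to tokens already generated. A minimal sketch of such a causal mask, assuming the attention code treats 0 as "masked" (the name subsequent_mask is illustrative, not from the original post):

def subsequent_mask(size):
    # upper triangle (future positions) is 1, so `== 0` keeps only
    # the current and earlier positions visible
    attn_shape = (1, size, size)
    future = torch.triu(torch.ones(attn_shape), diagonal=1)
    return future == 0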

1.2 Decoder

class Decoder(nn.Module):
    def __init__(self, layer, n):
        super(Decoder, self).__init__()
        # n identical decoder layers plus a final layer normalization
        self.layers = copy_model(layer, n)
        self.norm = LayerNorm(layer.embedding_dim)

    def forward(self, x, memory, source_mask, target_mask):
        for layer in self.layers:
            x = layer(x, memory, source_mask, target_mask)
        return self.norm(x)
attn = MultiHeadAttention(head, embedding_dim, dropout)
d_ff = 64
ff = PositionwiseFeedForwaed(embedding_dim, d_ff, dropout)
# a single attention module serves as both self- and cross-attention here
layer = DecoderLayer(embedding_dim, attn, attn, ff, dropout)
memory = en_result
x = pe_result
mask = torch.zeros(8, 4, 4)
source_mask = target_mask = mask

de = Decoder(layer, n)   # n: layer count, defined in the encoder section
de_result = de(x, memory, source_mask, target_mask)
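
Stacking layers and applying the final LayerNorm leaves the tensor shape unchanged, so a quick sanity check (assuming pe_result is (2, 4, embedding_dim) as in the encoder walkthrough) is:

print(de_result.shape)   # (2, 4, embedding_dim): the decoder is shape-preserving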

2 Output layer

class Generate(nn.Module):
    def __init__(self, embedding_dim, vocab_size):
        super(Generate, self).__init__()
        # project decoder features onto the vocabulary
        self.linear = nn.Linear(embedding_dim, vocab_size)

    def forward(self, x):
        # log-softmax over the vocabulary dimension yields log-probabilities
        return torch.log_softmax(self.linear(x), dim=-1)

x = de_result
gen = Generate(embedding_dim, vocab_size)
gen_result = gen(x)
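
Since the head ends in log_softmax, exponentiating the output and summing over the last dimension should give 1 at every position; this one-liner is a cheap correctness check:

print(gen_result.exp().sum(dim=-1))   # every entry ≈ 1.0: one probability distribution per position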

3 Combining the encoder and decoder

class EncoderDecoder(nn.Module):
    def __init__(self, encoder, decoder, source_embed, target_embed, generator):
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.source_embed = source_embed   # source-side embedding (+ positional encoding)
        self.target_embed = target_embed   # target-side embedding (+ positional encoding)
        self.generator = generator         # output projection head

    def forward(self, source, target, source_mask, target_mask):
        memory = self.encode(source, source_mask)
        return self.decode(memory, target, target_mask, source_mask)

    def encode(self, source, source_mask):
        return self.encoder(self.source_embed(source), source_mask)

    def decode(self, memory, target, target_mask, source_mask):
        return self.decoder(self.target_embed(target), memory, source_mask, target_mask)
        
encoder = en   # encoder from the encoder section
decoder = de
source_embed = nn.Embedding(vocab_size, embedding_dim)
target_embed = nn.Embedding(vocab_size, embedding_dim)
generator = gen
source = target = torch.LongTensor([[23, 12, 41, 12], [3, 12, 41, 12]])
source_mask = target_mask = torch.zeros(8, 4, 4)
ed = EncoderDecoder(encoder, decoder, source_embed, target_embed, generator)
ed_result = ed(source, target, source_mask, target_mask)
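
Exposing encode and decode separately makes autoregressive inference straightforward: encode the source once, then grow the target one token at a time, always re-masking future positions. A greedy-decoding sketch (a hypothetical helper, reusing the subsequent_mask sketch from section 1.1; start_symbol is an assumed BOS token id):

def greedy_decode(model, source, source_mask, max_len, start_symbol):
    memory = model.encode(source, source_mask)                 # encode once
    ys = torch.full((1, 1), start_symbol, dtype=torch.long)    # running target sequence
    for _ in range(max_len - 1):
        out = model.decode(memory, ys, subsequent_mask(ys.size(1)), source_mask)
        log_prob = model.generator(out[:, -1])                 # distribution for the newest position
        next_word = log_prob.argmax(dim=-1, keepdim=True)
        ys = torch.cat([ys, next_word], dim=1)                 # append the greedy choice
    return ys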

4 Assembling the Transformer


import copy

def make_model(source_vocab_size, target_vocab_size, N=6, embedding_dim=512, d_ff=2048, head=8, dropout=0.1):
    c = copy.deepcopy
    # prototype sublayers; each use below gets its own deep copy so that
    # no parameters are shared across layers
    self_attn = MultiHeadAttention(head, embedding_dim, dropout)
    ff = PositionwiseFeedForwaed(embedding_dim, d_ff, dropout)
    position = PositionEncoding(embedding_dim, dropout)

    model = EncoderDecoder(
        Encoder(EncoderLayer(embedding_dim, c(self_attn), c(ff), dropout), N),
        Decoder(DecoderLayer(embedding_dim, c(self_attn), c(self_attn), c(ff), dropout), N),
        nn.Sequential(Embeddings(embedding_dim, source_vocab_size), c(position)),
        nn.Sequential(Embeddings(embedding_dim, target_vocab_size), c(position)),
        Generate(embedding_dim, target_vocab_size)
    )
    # Xavier-uniform initialization for every weight matrix
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    return model

source_vocab_size = 11
target_vocab_size = 11
model = make_model(source_vocab_size, target_vocab_size)
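
A quick smoke test confirms the assembled model runs end to end; the token ids are arbitrary (below the vocabulary size of 11) and the masks follow the same demo convention as the sections above:

print(sum(p.numel() for p in model.parameters()))   # total parameter count

src = tgt = torch.LongTensor([[1, 2, 3, 4], [5, 6, 7, 8]])
src_mask = tgt_mask = torch.zeros(8, 4, 4)          # demo mask, as in the earlier sections
out = model(src, tgt, src_mask, tgt_mask)
print(out.shape)   # expected: torch.Size([2, 4, 512]), decoder features before the generator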

Reposted from blog.csdn.net/weixin_42529756/article/details/120319755