Transformer Decoder Principles
Decoder Layer

Each decoder layer contains three sublayers: masked multi-head self-attention over the target sequence, source attention over the encoder output (the "memory"), and a position-wise feed-forward network. Each sublayer is wrapped in a residual connection with layer normalization.
import copy
import torch
import torch.nn as nn

class DecoderLayer(nn.Module):
    """One decoder layer: masked self-attention, source attention, feed-forward."""
    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super().__init__()
        self.size = size
        self.self_attn = self_attn          # masked self-attention over the target
        self.src_attn = src_attn            # attention over the encoder output
        self.feed_forward = feed_forward    # position-wise feed-forward network
        # three residual sublayer connections (SublayerConnection and clones
        # are defined in an earlier section)
        self.sublayer = clones(SublayerConnection(size, dropout), 3)

    def forward(self, x, memory, source_mask, target_mask):
        m = memory
        # 1. self-attention, masked so each position sees only itself and earlier ones
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, target_mask))
        # 2. source attention: queries from the decoder, keys/values from memory
        x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, source_mask))
        # 3. feed-forward sublayer
        return self.sublayer[2](x, self.feed_forward)
# Build a single decoder layer and run one forward pass
dl = DecoderLayer(size, self_attn, src_attn, feed_forward, dropout)
dl_result = dl(x, memory, source_mask, target_mask)
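The target_mask passed to the self-attention sublayer is what makes decoding autoregressive: position i may only attend to positions at or before i. Below is a minimal sketch of such a causal mask; the helper name subsequent_mask is an assumption following The Annotated Transformer's convention, not something defined in this section.

def subsequent_mask(size):
    # Hypothetical helper: boolean mask of shape (1, size, size) that is True
    # where attention is allowed (the diagonal and below).
    future = torch.triu(torch.ones(1, size, size, dtype=torch.uint8), diagonal=1)
    return future == 0

# For a 4-token target sequence, each row unmasks one more position:
# [[ True, False, False, False],
#  [ True,  True, False, False],
#  [ True,  True,  True, False],
#  [ True,  True,  True,  True]]
print(subsequent_mask(4))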
Decoder

The full decoder stacks N identical decoder layers and applies one final layer normalization to the result.
class Decoder(nn.Module):
    """A stack of N identical decoder layers followed by a final layer norm."""
    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)     # N deep copies of the decoder layer
        self.norm = LayerNorm(layer.size)  # final layer normalization

    def forward(self, x, memory, source_mask, target_mask):
        # feed the target representation through each layer in turn
        for layer in self.layers:
            x = layer(x, memory, source_mask, target_mask)
        return self.norm(x)
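The code above relies on two helpers, clones and LayerNorm, presumably defined in an earlier section. For reference, minimal sketches under that assumption; they follow the usual Annotated Transformer implementations and may not match the earlier definitions exactly.

def clones(module, N):
    # produce a ModuleList of N independent deep copies of a module
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

class LayerNorm(nn.Module):
    # layer normalization over the last (feature) dimension
    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.a_2 = nn.Parameter(torch.ones(features))   # learned scale
        self.b_2 = nn.Parameter(torch.zeros(features))  # learned shift
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2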
# Assemble a full decoder: deep-copy the attention and feed-forward modules
# so that every layer gets its own parameters
c = copy.deepcopy
attn = MultiHeadedAttention(head, d_model)
feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
layer = DecoderLayer(size, c(attn), c(attn), c(feed_forward), dropout)
de = Decoder(layer, N)
output_de = de(x, memory, source_mask, target_mask)
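As a quick sanity check, the decoder preserves the shape of its input: a (batch, target_len, d_model) tensor goes in and comes out. A sketch with illustrative dimensions, assuming the hyperparameters above (head, d_model, d_ff, size, N, dropout) were set earlier, e.g. d_model = 512; the concrete sizes here are assumptions, not fixed by this text.

batch, src_len, tgt_len = 2, 10, 7               # illustrative sizes (assumptions)
x = torch.randn(batch, tgt_len, d_model)         # dummy target embeddings
memory = torch.randn(batch, src_len, d_model)    # dummy encoder output
source_mask = torch.ones(batch, 1, src_len, dtype=torch.bool)  # no source padding
target_mask = subsequent_mask(tgt_len)           # causal mask from the sketch above

out = de(x, memory, source_mask, target_mask)
print(out.shape)   # torch.Size([2, 7, 512]) when d_model is 512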