[PyTorch] PyTorch implementation of a Transformer for simple translation


Paper address: https://arxiv.org/pdf/1706.03762.pdf

Code reference: https://wmathor.com/index.php/archives/1455/

Note: apart from the dropout inside PositionalEncoding, this code does not apply dropout when building the Transformer model (there is no dropout in the attention or feed-forward sub-layers).
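
If you do want the sub-layer dropout described in the paper, one possible pattern (a minimal hypothetical sketch, not part of the reference code) is to wrap each residual connection in a small helper that applies dropout to the sub-layer output before the residual addition and LayerNorm:

import torch.nn as nn

# Hypothetical helper (sketch only): residual connection with dropout and LayerNorm
class SublayerConnection(nn.Module):
    def __init__(self, d_model, p_drop=0.1):
        super(SublayerConnection, self).__init__()
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(p=p_drop)

    def forward(self, x, sublayer_output):
        # apply dropout to the sub-layer output, then add the residual and normalize
        return self.norm(x + self.dropout(sublayer_output))

With such a helper, the nn.LayerNorm(d_model).cuda()(output + residual) lines in MultiHeadAttention and PoswiseFeedForwardNet below could be replaced by a single call to it.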

Data preprocessing

Two German → English sentence pairs are used, and the index of each word is hard-coded by hand to keep the code easy to read. From them we construct the encoder input, the decoder input, and the decoder output, which serves as the ground-truth labels.

import math
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data

# S: start-of-sequence symbol
# E: end-of-sequence symbol
# P: padding symbol, used when a sentence is shorter than the (manually chosen) maximum length
sentences = [
    # enc_input (encoder input)   dec_input (decoder input)   dec_output (decoder ground-truth labels)
    ['ich mochte ein bier P', 'S i want a beer .', 'i want a beer . E'],
    ['ich mochte ein cola P', 'S i want a coke .', 'i want a coke . E']
]

# Build the source and target vocabularies
# Padding should be zero
src_vocab = {'P': 0, 'ich': 1, 'mochte': 2, 'ein': 3, 'bier': 4, 'cola': 5}
src_vocab_size = len(src_vocab)  # 6

tgt_vocab = {'P': 0, 'i': 1, 'want': 2, 'a': 3, 'beer': 4, 'coke': 5, 'S': 6, 'E': 7, '.': 8}
tgt_vocab_size = len(tgt_vocab)  # 9

# Map indices back to words: {0: 'P', 1: 'i', 2: 'want', ..., 8: '.'}, used at prediction time
idx2word = {i: w for i, w in enumerate(tgt_vocab)}  # i is the index, w is the word (dict key)

src_len = 5  # enc_input max sequence length
tgt_len = 6  # dec_input(=dec_output) max sequence length


# Build the encoder inputs enc_inputs, the decoder inputs dec_inputs, and the decoder outputs dec_outputs (the ground-truth labels)
def make_data(sentences):
    enc_inputs, dec_inputs, dec_outputs = [], [], []
    for i in range(len(sentences)):
        enc_input = [[src_vocab[n] for n in sentences[i][0].split()]]
        dec_input = [[tgt_vocab[n] for n in sentences[i][1].split()]]
        dec_output = [[tgt_vocab[n] for n in sentences[i][2].split()]]

        enc_inputs.extend(enc_input)  # [[1, 2, 3, 4, 0], [1, 2, 3, 5, 0]]
        dec_inputs.extend(dec_input)  # [[6, 1, 2, 3, 4, 8], [6, 1, 2, 3, 5, 8]]
        dec_outputs.extend(dec_output)  # [[1, 2, 3, 4, 8, 7], [1, 2, 3, 5, 8, 7]]

    return torch.LongTensor(enc_inputs), torch.LongTensor(dec_inputs), torch.LongTensor(dec_outputs)


enc_inputs, dec_inputs, dec_outputs = make_data(sentences)  # outputs are tensors
# enc_inputs: [batch_size, src_len]=[2,5]
# dec_inputs/dec_outputs: [batch_size, tgt_len]=[2,6]

class MyDataSet(Data.Dataset):
    def __init__(self, enc_inputs, dec_inputs, dec_outputs):
        super(MyDataSet, self).__init__()
        self.enc_inputs = enc_inputs
        self.dec_inputs = dec_inputs
        self.dec_outputs = dec_outputs

    def __len__(self):
        return self.enc_inputs.shape[0]  # 2

    def __getitem__(self, idx):
        return self.enc_inputs[idx], self.dec_inputs[idx], self.dec_outputs[idx]


# Since there are only two sentences, batch_size is set to 2
loader = Data.DataLoader(MyDataSet(enc_inputs, dec_inputs, dec_outputs), batch_size=2, shuffle=True)

Positional Encoding

The encoding for each position and dimension is:
$$PE_{(pos,2i)} = \sin\left(pos/10000^{2i/d_{model}}\right) \qquad PE_{(pos,2i+1)} = \cos\left(pos/10000^{2i/d_{model}}\right)$$

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)  # initialize pe
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # position indices pos
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)  # sin for even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # cos for odd dimensions
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        '''
        x: sequence of embeddings, [seq_len, batch_size, d_model]
        '''
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
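
Note that div_term is simply 10000^{-2i/d_model} written through exp/log, since exp(-(2i/d_model) * ln(10000)) = 10000^{-2i/d_model}; position * div_term is therefore exactly the argument pos/10000^{2i/d_model} of the sine and cosine above.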

Plot the positional encoding to visualize it:

import matplotlib.pyplot as plt
plt.figure(figsize=(15, 5))
pe = PositionalEncoding(20, 0)
y = pe.forward((torch.zeros(100, 1, 20)))
plt.plot(np.arange(100), y[:, 0, 4:8].data.numpy())
plt.legend(["dim %d"%p for p in [4,5,6,7]])
None

[Figure: sinusoidal positional-encoding curves for dimensions 4-7 over the first 100 positions]

Model parameters

d_model = 512  # Embedding Size
d_ff = 2048  # FeedForward dimension
d_k = d_v = 64  # dimension of K(=Q), V
n_layers = 6  # number of Encoder and Decoder layers
n_heads = 8  # number of heads in Multi-Head Attention

get_attn_pad_mask

def get_attn_pad_mask(seq_q, seq_k):
    """
    seq_q: [batch_size, len_q]
    seq_k: [batch_size, len_k]
    seq_q and seq_k do not have to be the same; len_q and len_k may differ
    """
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    # eq(zero) marks the PAD token: positions equal to 0 are set to True
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)  # [batch_size, 1, len_k]; only seq_k's pad information is used
    return pad_attn_mask.expand(batch_size, len_q, len_k)  # [batch_size, len_q, len_k]

Print it out and see the effect:

dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs) 
# In the encoder-decoder (cross) attention layer, only the pad information of enc_inputs is used; the decoder-side pad information is not used
dec_enc_attn_mask

# output shape: (2, 6, 5)
tensor([[[False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True]],

        [[False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True]]])

get_attn_subsequence_mask

# Source of the mask used by the decoder's Masked Multi-Head Attention; masking makes parallel training possible
def get_attn_subsequence_mask(seq):
    """
    seq: dec_inputs, [batch_size, tgt_len]
    """
    attn_shape = [seq.size(0), seq.size(1), seq.size(1)]  # [batch_size, tgt_len, tgt_len]
    subsequence_mask = np.triu(np.ones(attn_shape), k=1)  # upper-triangular matrix of ones; k=1 keeps the diagonal at 0
    subsequence_mask = torch.from_numpy(subsequence_mask).byte()  # convert to a tensor
    return subsequence_mask  # [batch_size, tgt_len, tgt_len]

Print it out and see the effect:

get_attn_subsequence_mask(dec_inputs)

# output shape: (2, 6, 6)
tensor([[[0, 1, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1],
         [0, 0, 0, 1, 1, 1],
         [0, 0, 0, 0, 1, 1],
         [0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 0]],

        [[0, 1, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1],
         [0, 0, 0, 1, 1, 1],
         [0, 0, 0, 0, 1, 1],
         [0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)

The mask can also be visualized:

import matplotlib.pyplot as plt
a = torch.randn((5, 20))  # random standard-normal tensor, [batch_size, len]
plt.figure(figsize=(5, 5))
plt.imshow(get_attn_subsequence_mask(a)[0])  # [batch_size, len, len]; show the 0th mask
None

Scaled Dot-Product Attention

Scaled Dot-Product Attention is part of Multi-Head Attention.

$$Attention(Q,K,V) = softmax\left(\frac{QK^{T}}{\sqrt{d_{k}}}\right)V$$

class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, Q, K, V, attn_mask):
        """
        Q: [batch_size, n_heads, len_q, d_k]
        K: [batch_size, n_heads, len_k, d_k]
        V: [batch_size, n_heads, len_v(=len_k), d_v]
        attn_mask: [batch_size, n_heads, seq_len, seq_len]
        """
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k)  # [,,len_q,d_k]*[,,d_k,len_k]=[,,len_q,len_k]
        # scores : [batch_size, n_heads, len_q, len_k]
        scores.masked_fill_(attn_mask, -1e9)  # positions where the mask is True are set to a very large negative value, so they become 0 after softmax
        # attn_mask has the same shape as scores: [batch_size, n_heads, len_q, len_k]

        attn = nn.Softmax(dim=-1)(scores)  # [batch_size, n_heads, len_q, len_k]
        # attn is the attention distribution after softmax; each row sums to 1
        context = torch.matmul(attn, V)   # [,,len_q,len_k]*[,,len_v(=len_k),d_v]=[,,len_q,d_v]
        # context: result of attending over V with the attention weights, [batch_size, n_heads, len_q, d_v]
        return context, attn

Multi-Head Attention

class MultiHeadAttention(nn.Module):
    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=False)
        self.fc = nn.Linear(n_heads * d_v, d_model, bias=False)

    def forward(self, input_Q, input_K, input_V, attn_mask):
        """
        input_Q: [batch_size, len_q, d_model]
        input_K: [batch_size, len_k, d_model]
        input_V: [batch_size, len_v(=len_k), d_model]
        attn_mask: [batch_size, seq_len, seq_len]
        :return: output after multi-head attention + residual + LayerNorm, same shape as input_Q
        """
        residual, batch_size = input_Q, input_Q.size(0)
        # (B, S, D) -proj-> (B, S, D_new) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        Q = self.W_Q(input_Q).view(batch_size, -1, n_heads, d_k).transpose(1, 2)  # Q: [batch_size, n_heads, len_q, d_k]
        K = self.W_K(input_K).view(batch_size, -1, n_heads, d_k).transpose(1, 2)  # K: [batch_size, n_heads, len_k, d_k]
        V = self.W_V(input_V).view(batch_size, -1, n_heads, d_v).transpose(1, 2)  # V: [batch_size, n_heads, len_v(=len_k), d_v]

        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)
        # attn_mask : [batch_size, n_heads, seq_len, seq_len]
        # repeat(): copy n_heads times along dim 1 and once along every other dim

        # context: [batch_size, n_heads, len_q, d_v], attn: [batch_size, n_heads, len_q, len_k]
        context, attn = ScaledDotProductAttention()(Q, K, V, attn_mask)
        context = context.transpose(1, 2).reshape(batch_size, -1, n_heads * d_v)  # this is the concat step in the figure
        # context: [batch_size, len_q, n_heads * d_v]

        output = self.fc(context)  # [batch_size, len_q, d_model]
        # Note: the LayerNorm is created anew on every forward call here, so its affine parameters are never trained
        return nn.LayerNorm(d_model).cuda()(output + residual), attn  # residual + LayerNorm keeps the shape unchanged

Feed-Forward Network

# Position-wise feed-forward network; the input and output shapes are the same
class PoswiseFeedForwardNet(nn.Module):
    def __init__(self):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=False),
            nn.ReLU(),
            nn.Linear(d_ff, d_model, bias=False)
        )

    def forward(self, inputs):
        """
        inputs: [batch_size, seq_len, d_model] 
        """
        residual = inputs
        output = self.fc(inputs)
        return nn.LayerNorm(d_model).cuda()(output + residual)  # [batch_size, seq_len, d_model]; LayerNorm created per call, as in MultiHeadAttention

Encoder Layer

# Multi-head self-attention + feed-forward network
class EncoderLayer(nn.Module):
    def __init__(self):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention()  # naming: encoder self-attention
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, enc_inputs, enc_self_attn_mask):
        """
        enc_inputs: [batch_size, src_len, d_model]
        enc_self_attn_mask: [batch_size, src_len, src_len]
        """
        enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask)  # Q, K, V all come from the same source
        # enc_outputs: [batch_size, src_len, d_model], attn: [batch_size, n_heads, src_len, src_len]
        enc_outputs = self.pos_ffn(enc_outputs)  # enc_outputs: [batch_size, src_len, d_model]
        return enc_outputs, attn  # enc_outputs has the same shape as enc_inputs

Encoder

# The Encoder consists of three parts: word embedding, positional encoding, and n_layers stacked EncoderLayers (attention + FFN)
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.src_emb = nn.Embedding(src_vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])  # stack multiple EncoderLayers with ModuleList

    def forward(self, enc_inputs):
        """
        enc_inputs: torch.Size([batch_size, src_len])
        """
        enc_outputs = self.src_emb(enc_inputs)  # [batch_size, src_len, d_model]
        enc_outputs = self.pos_emb(enc_outputs.transpose(0, 1)).transpose(0, 1)  # [batch_size, src_len, d_model]
        # PositionalEncoding expects [seq_len, batch_size, d_model], so the first two dimensions are transposed
        # the positional encoding does not change the shape
        enc_self_attn_mask = get_attn_pad_mask(enc_inputs, enc_inputs)  # [batch_size, src_len, src_len]

        enc_self_attns = []
        for layer in self.layers:
            enc_outputs, enc_self_attn = layer(enc_outputs, enc_self_attn_mask)
            # enc_outputs: [batch_size, src_len, d_model], enc_self_attn: [batch_size, n_heads, src_len, src_len]
            enc_self_attns.append(enc_self_attn)  # list of length n_layers
        return enc_outputs, enc_self_attns

Decoder Layer

# Three parts: masked multi-head self-attention + encoder-decoder multi-head attention + FFN
class DecoderLayer(nn.Module):
    def __init__(self):
        super(DecoderLayer, self).__init__()
        self.dec_self_attn = MultiHeadAttention()  # naming: decoder self-attention
        self.dec_enc_attn = MultiHeadAttention()  # naming: decoder-encoder (cross) attention
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):
        """
        dec_inputs: [batch_size, tgt_len, d_model]
        enc_outputs: [batch_size, src_len, d_model]
        dec_self_attn_mask: [batch_size, tgt_len, tgt_len]
        dec_enc_attn_mask: [batch_size, tgt_len, src_len]
        return: dec_outputs has the same shape as dec_inputs
        """
        dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)  # Q, K, V all come from the decoder
        # dec_outputs: [batch_size, tgt_len, d_model], dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len]

        dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, enc_outputs, dec_enc_attn_mask)
        # Q comes from the output of the decoder's masked self-attention; K and V come from the output of the encoder stack
        # dec_outputs: [batch_size, tgt_len, d_model], dec_enc_attn: [batch_size, n_heads, tgt_len, src_len]

        dec_outputs = self.pos_ffn(dec_outputs)  # [batch_size, tgt_len, d_model]
        return dec_outputs, dec_self_attn, dec_enc_attn

Decoder

# The Decoder consists of three parts: word embedding, positional encoding, and n_layers stacked DecoderLayers
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)])

    def forward(self, dec_inputs, enc_inputs, enc_outputs):
        """
        dec_inputs: [batch_size, tgt_len]
        enc_inputs: [batch_size, src_len]
        enc_outputs: [batch_size, src_len, d_model]
        """
        dec_outputs = self.tgt_emb(dec_inputs)  # [batch_size, tgt_len, d_model]
        dec_outputs = self.pos_emb(dec_outputs.transpose(0, 1)).transpose(0, 1).cuda()  # [batch_size, tgt_len, d_model]


        # dec_self_attn_pad_mask: pad mask for the decoder self-attention (bool tensor):
        dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs).cuda()  # [batch_size, tgt_len, tgt_len]

        # dec_self_attn_subsequence_mask: mask for the self-attention layer so that words after the current one cannot be seen; an upper-triangular matrix of ones
        dec_self_attn_subsequence_mask = get_attn_subsequence_mask(dec_inputs).cuda()  # [batch_size, tgt_len, tgt_len]

        # add the two masks: entries greater than 0 become True (and will later be filled with a large negative value), the rest become False
        dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_subsequence_mask), 0).cuda()
        # bool tensor, [batch_size, tgt_len, tgt_len]

        # build the mask for the encoder-decoder (cross) attention
        dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs)  # [batch_size, tgt_len, src_len] [2, 6, 5]
        # i.e. the self-attention layer uses dec_self_attn_mask, and the cross-attention layer uses dec_enc_attn_mask

        dec_self_attns, dec_enc_attns = [], []
        for layer in self.layers:
            dec_outputs, dec_self_attn, dec_enc_attn = layer(dec_outputs, enc_outputs, dec_self_attn_mask,
                                                             dec_enc_attn_mask)
            # dec_outputs: [batch_size, tgt_len, d_model], dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len], dec_enc_attn: [batch_size, n_heads, tgt_len, src_len]
            dec_self_attns.append(dec_self_attn)
            dec_enc_attns.append(dec_enc_attn)
        return dec_outputs, dec_self_attns, dec_enc_attns

Transformer

# Encoder + Decoder + final linear projection layer
class Transformer(nn.Module):
    def __init__(self):
        super(Transformer, self).__init__()
        self.encoder = Encoder().cuda()
        self.decoder = Decoder().cuda()
        self.projection = nn.Linear(d_model, tgt_vocab_size, bias=False).cuda()

    def forward(self, enc_inputs, dec_inputs):
        """
        enc_inputs: [batch_size, src_len]
        dec_inputs: [batch_size, tgt_len]
        """
        enc_outputs, enc_self_attns = self.encoder(enc_inputs)
        # enc_outputs: [batch_size, src_len, d_model], enc_self_attns: [n_layers, batch_size, n_heads, src_len, src_len]
        dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs, enc_inputs, enc_outputs)
        # dec_outputs: [batch_size, tgt_len, d_model], dec_self_attns: [n_layers, batch_size, n_heads, tgt_len, tgt_len]
        # dec_enc_attn: [n_layers, batch_size, n_heads, tgt_len, src_len]
        dec_logits = self.projection(dec_outputs)  # dec_logits: [batch_size, tgt_len, tgt_vocab_size]
        return dec_logits.view(-1, dec_logits.size(-1)), enc_self_attns, dec_self_attns, dec_enc_attns
        # flattened to (batch_size * tgt_len, tgt_vocab_size) for computing the loss against the target vocabulary

Model, loss function, and optimizer

In the loss function, the parameter ignore_index=0 is set because the padding token "P" has index 0. With this setting, positions whose target is the padding token are excluded from the loss (padding carries no meaning and should not contribute to it).

model = Transformer().cuda()
criterion = nn.CrossEntropyLoss(ignore_index=0)  # the final softmax lives here, inside the cross-entropy loss
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.99)  # stochastic gradient descent with momentum
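
As a quick, illustrative sanity check (not from the original post; the variable names here are hypothetical), the snippet below shows how ignore_index=0 excludes padded positions from the loss, reusing tgt_vocab_size from above:

logits = torch.randn(4, tgt_vocab_size)     # scores for 4 target positions over the target vocabulary
targets = torch.LongTensor([1, 2, 0, 0])    # the last two positions are padding ('P' has index 0)
loss_all = nn.CrossEntropyLoss()(logits, targets)                    # averages over all 4 positions
loss_no_pad = nn.CrossEntropyLoss(ignore_index=0)(logits, targets)   # averages over the 2 non-pad positions only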

Train

for epoch in range(1000):
    for enc_inputs, dec_inputs, dec_outputs in loader:
        # enc_inputs: [batch_size, src_len] tensor
        # dec_inputs: [batch_size, tgt_len]
        # dec_outputs: [batch_size, tgt_len]

        enc_inputs, dec_inputs, dec_outputs = enc_inputs.cuda(), dec_inputs.cuda(), dec_outputs.cuda()
        outputs, enc_self_attns, dec_self_attns, dec_enc_attns = model(enc_inputs, dec_inputs)
        # outputs: [batch_size * tgt_len, tgt_vocab_size]
        loss = criterion(outputs, dec_outputs.view(-1))  # dec_outputs is flattened to [batch_size * tgt_len]
        print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

Output result:

Epoch: 0001 loss = 1.058965
Epoch: 0002 loss = 0.938208
Epoch: 0003 loss = 0.738537
Epoch: 0004 loss = 0.628805
Epoch: 0005 loss = 0.472079
Epoch: 0006 loss = 0.394795
......

Test

Set breakpoints to observe the prediction process:

# At inference time the target sequence is unknown, so the decoder input is generated token by token and fed back into the Transformer.
# Decoding starts with start_symbol as the first decoder input.
# Each round's prediction is appended as the next round's input, until the index of '.' is predicted.
def greedy_decoder(model, enc_input, start_symbol):  # start_symbol = 6 (int)
    """
    :param model: Transformer Model
    :param enc_input: The encoder input [1, src_len] 
    :param start_symbol: The start symbol. In this example it is 'S' which corresponds to index 6
    :return: The target input
    """
    enc_outputs, enc_self_attns = model.encoder(enc_input)
    # after the encoder, enc_input: (1, src_len) -> enc_outputs: (1, src_len, 512)
    dec_input = torch.zeros(1, 0).type_as(enc_input.data)  # tensor([])
    terminal = False
    next_symbol = start_symbol
    while not terminal:  # loop starts from ["S"], whose vocabulary index is tensor(6)
        dec_input = torch.cat([dec_input.detach(), torch.tensor([[next_symbol]], dtype=enc_input.dtype).cuda()], -1)
        # shape/data: (1,1)/([[6]]) -> (1,2)/([[6,1]]) -> (1,3)/([[6,1,2]]) -> ...
        # the previous round's prediction becomes part of the next round's input
        dec_outputs, _, _ = model.decoder(dec_input, enc_input, enc_outputs)
        # after the decoder, dec_outputs: (1,1,512) -> (1,2,512) -> (1,3,512) -> ...
        projected = model.projection(dec_outputs)  # (1,1,9)->(1,2,9)->(1,3,9)->...
        prob = projected.squeeze(0).max(dim=-1, keepdim=False)[1]  # argmax along the last dim, i.e. the index of the predicted word
        # shape/data: (1,)/tensor([1]) -> (2,)/tensor([1,2]) -> (3,)/tensor([1,2,3]) -> ...
        # [1] selects the indices of the maxima returned by max()
        next_word = prob.data[-1]  # take the last predicted index, tensor(1) -> tensor(2) -> tensor(3) -> ...
        next_symbol = next_word
        if next_symbol == tgt_vocab["."]:  # stop once '.' (vocabulary index 8) is predicted
            terminal = True
        print(next_word)  # tensor(1)->(2)->(3)-> (4)-> (8)
    return dec_input


# Test
enc_inputs, _, _ = next(iter(loader))  # (2,5)
enc_inputs = enc_inputs.cuda()
for i in range(len(enc_inputs)):  # length 2
    greedy_dec_input = greedy_decoder(model, enc_inputs[i].view(1, -1), start_symbol=tgt_vocab["S"])  # [[6,1,2,3,4]]
    predict, _, _, _ = model(enc_inputs[i].view(1, -1), greedy_dec_input)  # input: shape (1,5), prediction: shape (5,9)
    predict = predict.data.max(1, keepdim=True)[1]  # argmax indices, shape (5,1)
    print(enc_inputs[i], '->', [idx2word[n.item()] for n in predict.squeeze()])

Output result:

tensor(1, device='cuda:0')
tensor(2, device='cuda:0')
tensor(3, device='cuda:0')
tensor(4, device='cuda:0')
tensor(8, device='cuda:0')
tensor([1, 2, 3, 4, 0], device='cuda:0') -> ['i', 'want', 'a', 'beer', '.']
tensor(1, device='cuda:0')
tensor(2, device='cuda:0')
tensor(3, device='cuda:0')
tensor(5, device='cuda:0')
tensor(8, device='cuda:0')
tensor([1, 2, 3, 5, 0], device='cuda:0') -> ['i', 'want', 'a', 'coke', '.']

Not all of the code is repeated here; copying and pasting the blocks above in order (except those that only print intermediate results, such as the positional-encoding plot) gives a runnable script. Note that the code calls .cuda() throughout, so it expects a CUDA-capable GPU.

Original post: https://blog.csdn.net/qq_45670134/article/details/128005237