pytorch学习笔记三——transformer

预备知识

transformer经典论文:attention is all you need

代码学习逻辑:

  • 从整体到局部
  • 关注每一部分数据流向(输入->运算->输出 矩阵维度变换)

模型架构

模型架构:

①和②是对Encoder和Decoder部分的输入进行token层面的embedding编码,③和④是对embedding后的信息加上一个位置信息(是在数据层面直接加,并非向量拼接)。

经过③后得到的数据Encoder_X,分别作为kqv送入⑤多头注意力中,此时⑤就是一个自注意力(因为kqv一样),得到一个结果Y。Encoder_X和Y一同送入⑥中做残差连接和层归一,输出Y,将Y送入到⑦位置前馈网络中得到一个结果与Y一同送入⑧中再进行残差归一输出Encode_Y,输出结果为Encoder一个块的结果,整个transformer Encoder是由n个Encoder块组成,所以Encoder块输入输出维度必须相同。

数据进入Encoder后得到一串输出一同送入Decoder中,作为⑪多头注意力的kv。经过②和④后得到的Decoder_X输入到⑨中得到一个结果Y,这是一个带掩码的多头自注意力,要知道attention是没有时间信息的,所以每一个Decoder块做⑨运算时将Decode_X中后面的信息遮掩住。Decoder_X和Y进行⑩残差归一后的结果Y,作为q与Encoder_Y作为kv输入到⑪中,得到的结果与Y进行⑫残差归一,再通过⑬位置前馈后做最后一次⑭残差归一,输出Decoder_Y,最后一层Decoder得到的结果再进行一次全连接层输出。

在这里插入图片描述

关注数据流向:每次运算的输入输出维度相同

在这里插入图片描述

机器翻译任务

在这里插入图片描述

在这里插入图片描述

实现

主体部分

import math
import torch
from torch import nn
from d2l import torch as d2l
# 读取数据
batch_size, num_steps = 64, 10  # 每一批数据64个句子,每个句子长度设置为10
train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
# train_iter:X, X_valid_len, Y, Y_valid_len
# src_vocab, tgt_vocab:Encoder词表,Decoder词表

在这里插入图片描述

在这里插入图片描述

# 设置超参
num_hiddens = 32  # 主要:每个token进行embedding的长度  全局:确保数据每一步运算时的输入、输出维度相同
num_layers = 2  # Encoder和Decoder块的数目
lr, device = 0.005, d2l.try_gpu()  # 学习率0.005  训练使用GPU
ffn_num_input, ffn_num_hiddens = 32, 64  # 位置前馈网络输入层、隐藏层神经元个数(输入层个数是根据num_hiddens设置的)
num_heads, key_size, query_size, value_size = 4, 32, 32, 32  # 多头注意力:头数必须整除embedding的长度,kqv长度也是根据num_hiddens设置的
norm_shape, dropout = [32], 0.1  # 归一化参数
# 实例化模型,配置超参
encoder = TransformerEncoder(len(src_vocab), key_size, query_size, value_size,
                             num_hiddens, norm_shape, ffn_num_input,
                             ffn_num_hiddens, num_heads, num_layers, dropout)
decoder = TransformerDecoder(len(tgt_vocab), key_size, query_size, value_size,
                             num_hiddens, norm_shape, ffn_num_input,
                             ffn_num_hiddens, num_heads, num_layers, dropout)
net = EncoderDecoder(encoder, decoder)
net.to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
loss = MaskedSoftmaxCELoss()
net.train()

for batch in train_iter:
    optimizer.zero_grad()
    X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]  # X,Y:(64*10) X_valid_len,Y_valid_len:(64,)
    bos = torch.tensor([tgt_vocab['<bos>']] * Y.shape[0], device=device).reshape(-1, 1)  # (64,1)
    dec_input = d2l.concat([bos, Y[:, :9]], 1)  # concat:(64,1),(64,9)->(64,10)

    Y_hat, _ = net(X, dec_input, X_valid_len)  # net(encoder输入数据, decoder输入数据, encoder输入数据的真实长度)
    l = loss(Y_hat, Y, Y_valid_len)

    l.sum().backward()
    d2l.grad_clipping(net, 1)  # 梯度剪裁
    optimizer.step()
    # break

在这里插入图片描述

模型构建

Encoder-Decoder

class EncoderDecoder(nn.Module):
    """The base class for the encoder-decoder architecture."""
    def __init__(self, encoder, decoder, **kwargs):
        super(EncoderDecoder, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
        enc_outputs = self.encoder(enc_X, *args)
        dec_state = self.decoder.init_state(enc_outputs, *args)
        res = self.decoder(dec_X, dec_state)
        return res

Encoder部分

class TransformerEncoder(d2l.Encoder):
    """在下面的转换器编码器实现中,我们堆叠 num_layers了上述EncoderBlock类的实例。
    由于我们使用值始终在 -1 和 1 之间的固定位置编码,因此我们将可学习输入嵌入的值乘以嵌入维度的平方根以重新缩放,
    然后再对输入嵌入和位置编码求和。"""
    def __init__(self, vocab_size, key_size, query_size, value_size,
                 num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
                 num_heads, num_layers, dropout, use_bias=False, **kwargs):
        super(TransformerEncoder, self).__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()  # 每一个块都是一个有序的容器
        # 有 num_layers 个 transformerBlock
        for i in range(num_layers):
            self.blks.add_module(
                "block" + str(i),
                EncoderBlock(key_size, query_size, value_size, num_hiddens,
                             norm_shape, ffn_num_input, ffn_num_hiddens,
                             num_heads, dropout, use_bias))

    def forward(self, X, valid_lens, *args):
        # 由于位置编码pos_encoding值在-1和1之间,嵌入值乘以嵌入维度的平方根以重新缩放,然后再求和
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        # print(X.shape)
        for i, blk in enumerate(self.blks):
            X = blk(X, valid_lens)
        return X
class EncoderBlock(nn.Module):
    """EncoderBlock类包含两个子层:多头自注意和位置前馈网络,其中在两个子层周围采用残差连接,然后进行层归一化。"""
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
                 dropout, use_bias=False, **kwargs):
        super(EncoderBlock, self).__init__(**kwargs)
        self.attention = MultiHeadAttention(key_size, query_size,
                                                value_size, num_hiddens,
                                                num_heads, dropout, use_bias)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens,
                                   num_hiddens)
        self.addnorm2 = AddNorm(norm_shape, dropout)

    def forward(self, X, valid_lens):
        # 自注意力:三个X分别代表key value query
        Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))
        return self.addnorm2(Y, self.ffn(Y))

残差连接和层归一、位置前馈、多头注意力

class AddNorm(nn.Module):
    """使用残差连接和层归一化来实现该类。
    要求X,Y的大小与normalized_shape相同,以便在加法运算后输出张量也具有相同的形状。"""
    def __init__(self, normalized_shape, dropout, **kwargs):
        super(AddNorm, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self, X, Y):
        return self.ln(self.dropout(Y) + X)
class PositionWiseFFN(nn.Module):
    """位置前馈网络: 使用相同的MLP转换所有序列位置的表示。这就是我们称之为 positionwise 的原因。
        在下面的实现中,X具有形状(批量大小、时间步长或序列长度、隐藏单元数量或特征维度)
        的输入将被两层 MLP 转换为形状(批量大小、时间步数, ffn_num_outputs)。"""
    def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs,
                 **kwargs):
        super(PositionWiseFFN, self).__init__(**kwargs)
        self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)

    def forward(self, X):
        return self.dense2(self.relu(self.dense1(X)))
class MultiHeadAttention(nn.Module):
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)

        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)  # 缩放点积模型 求注意力分数
        # bias=False 不计算偏差目的:特征维度转换 (64,10,32)->(64,10,32)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # 保留分歧:每个token的embedding按头的数量切分,还是按头的数量复制?不能拿人的规则限制它
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        if valid_lens is not None:
            valid_lens = torch.repeat_interleave(valid_lens,
                                                 repeats=self.num_heads,
                                                 dim=0)

        output = self.attention(queries, keys, values, valid_lens)

        output_concat = transpose_output(output, self.num_heads)
        # 拼接后再加一层线性映射的目的:就是给多个特征再进行一次加权融合

        return self.W_o(output_concat)
def transpose_qkv(X, num_heads):
    # print("self.W_ shape:", X.shape)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
    # print("reshape(X.shape[0], X.shape[1], num_heads, -1) shape:", X.shape)

    X = X.permute(0, 2, 1, 3)
    # print("permute: ", X.shape)

    return X.reshape(-1, X.shape[2], X.shape[3])


def transpose_output(X, num_heads):
    """ 逆转 `transpose_qkv` 函数的操作 """
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)

Decoder部分

class TransformerDecoder(d2l.AttentionDecoder):
    def __init__(self, vocab_size, key_size, query_size, value_size,
                 num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
                 num_heads, num_layers, dropout, **kwargs):
        super(TransformerDecoder, self).__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()  # 每一个块都是一个有序的容器
        for i in range(num_layers):
            self.blks.add_module(
                "block" + str(i),
                DecoderBlock(key_size, query_size, value_size, num_hiddens,
                             norm_shape, ffn_num_input, ffn_num_hiddens,
                             num_heads, dropout, i))
        # 英法翻译,所以最后一层每个元素对应输出一个vocab_size
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, enc_valid_lens, *args):
        return [enc_outputs, enc_valid_lens, [None] * self.num_layers]

    def forward(self, X, state):
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))

        for i, blk in enumerate(self.blks):
            X, state = blk(X, state)
        return self.dense(X), state

    # 广泛用于类的定义中,把方法变成属性,保证对参数进行必要的检查,减少程序运行时出错的可能性。
    @property
    def attention_weights(self):
        return self._attention_weights

class DecoderBlock(nn.Module):
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
                 dropout, i, **kwargs):
        super(DecoderBlock, self).__init__(**kwargs)
        self.i = i
        self.attention1 = MultiHeadAttention(key_size, query_size,
                                                 value_size, num_hiddens,
                                                 num_heads, dropout)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.attention2 = d2l.MultiHeadAttention(key_size, query_size,
                                                 value_size, num_hiddens,
                                                 num_heads, dropout)
        self.addnorm2 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens,
                                   num_hiddens)
        self.addnorm3 = AddNorm(norm_shape, dropout)

    def forward(self, X, state):
        # state中有三组数据,0:Encoder的输出 1:Encoder的长度 2:过去的记忆
        # query是一个一个输入的,这样输入第i个的时候,在key和value中需要存i和之前的所有
        enc_outputs, enc_valid_lens = state[0], state[1]
        # 在train期间,任何输出序列的所有标记都会同时处理,因此初始化时state[2][self.i]为None。
        # 在predict期间,通过标记解码任何输出序列标记时,state[2][self.i]包含在第i个块中直到当前时间步的解码输出的表示
        if state[2][self.i] is None:
            key_values = X
        else:
            # key_values之前的东西不断存入
            key_values = torch.cat((state[2][self.i], X), axis=1)

        state[2][self.i] = key_values
        # 训练到i的时候把后面的遮住,所以需要dec_valid_lens标记长度,测试的时候不需要
        # dec_valid_lens为了在train时不关注后面的内容
        if self.training:
            batch_size, num_steps, _ = X.shape
            # dec_valid_lens的形状:(batch_size, num_steps),其中每一行是 [1, 2, ..., num_steps]
            dec_valid_lens = torch.arange(1, num_steps + 1, device=X.device).repeat(batch_size, 1)
        else:
            dec_valid_lens = None

        X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
        Y = self.addnorm1(X, X2)
        # Encoder-decoder attention: enc_outputs的形状:(batch_size num_steps num_hiddens)
        # key和value均来自Encoder Output
        Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
        Z = self.addnorm2(Y, Y2)

        return self.addnorm3(Z, self.ffn(Z)), state

题外话

一个大佬分享的各种编程书籍:https://github.com/XiangLinPro/IT_book#%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E6%A1%86%E6%9E%B6-pytorch

猜你喜欢

转载自blog.csdn.net/qq_41754907/article/details/121844401