对Transformer中Positional Encoding的理解

Positional Encoding


首先来看一下Transformer结构的结构:
在这里插入图片描述
Transformer是由Encoder和Decoder两大部分组成,首先对于文本特征,需要进行Embedding,由于transformer抛弃了Rnn的结构,不能捕捉到序列的信息,交换单词位置,得到相应的attention也会发生交换,并不会发生数值上的改变,所以要对input进行Positional Encoding。

Positional encoding和input embedding是同等维度的,所以可以将两者进行相加,的到输入向量
在这里插入图片描述
接下来看一些Positional Encoding的计算公式:
在这里插入图片描述
其中pos表示token在sequence中的位置,d_model表示词嵌入的维度,i则是range(d_model)中的数值,也就是说:对于单个token的d_model维度的词向量,奇数位置取cos,偶数位置取sin,最终的到一个维度和word embedding维度一样的矩阵,接下来可以看一下代码:

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns


def get_positional_encoding(max_seq_len, embed_dim):
    # 初始化一个positional encoding
    # embed_dim: 字嵌入的维度
    # max_seq_len: 最大的序列长度
    positional_encoding = np.array([
        [pos / np.power(10000, 2 * i / embed_dim) for i in range(embed_dim)] if pos != 0 else np.zeros(embed_dim) for pos in range(max_seq_len)])

    positional_encoding[1:, 0::2] = np.sin(positional_encoding[1:, 0::2])  # dim 2i 偶数
    positional_encoding[1:, 1::2] = np.cos(positional_encoding[1:, 1::2])  # dim 2i+1 奇数
    return positional_encoding


positional_encoding = get_positional_encoding(max_seq_len=100, embed_dim=16)
plt.figure(figsize=(10, 10))
sns.heatmap(positional_encoding)
plt.title("Sinusoidal Function")
plt.xlabel("hidden dimension")
plt.ylabel("sequence length")
plt.show()

首先求初始向量:positional_encoding,然后对其奇数列求sin,偶数列求cos:
在这里插入图片描述
最终得到positional encoding之后的数据可视化:
在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/weixin_51756104/article/details/127191868