正弦、余弦三角函数位置编码讲解、代码实现



一、正弦、余弦三角函数位置编码讲解

在Transformer中,位置编码是为了引入位置信息,而位置编码的形式通常是一个正弦函数和一个余弦函数的组合,公式如下:
计算公式

其中,PE(pos,i)​表示位置编码矩阵中第 pos 个位置,第 i 个维度的值;dmodel​表示模型嵌入向量的维度;i表示位置编码矩阵中第 i 个维度的值。这种位置编码方式可以引入位置信息,使得Transformer模型可以处理序列数据。
假设序列长度为4,位置编码维度为6,则位置编码矩阵如下:
在这里插入图片描述
其中三角函数括号中的部分可以由*号拆分成两部分,第一部分可以理解为x,第二部分可以理解为周期(普通的三角函数sin(2ΠX)的周期T为2Π,X为因变量)。
按列分析:如dim0这一列周期T为在这里插入图片描述
X为0~3的一个周期为定值的三角函数;
按行分析
如pos0这一行中,周期每两个元素变化一次,X为递增数列;所以按行看每个pos的位置编码是一个变周期(T)的三角函数;

二、代码实现

代码如下(示例)
1、实现上表中的矩阵:

import torch
def creat_pe_absolute_sincos_embedding(n_pos_vec, dim):
  assert dim % 2 == 0, "wrong dim"
  position_embedding = torch.zeros(n_pos_vec.numel(), dim, dtype=torch.float)

  omega = torch.arange(dim//2, dtype=torch.float)
  omega /= dim/2.
  omega = 1./(10000**omega)

  sita = n_pos_vec[:,None] @ omega[None,:]
  emb_sin = torch.sin(sita)
  emb_cos = torch.cos(sita)

  position_embedding[:,0::2] = emb_sin
  position_embedding[:,1::2] = emb_cos

  return position_embedding

2、初始化序列长度和位置编码的维度,并计算位置编码矩阵:

n_pos = 512
dim = 768
n_pos_vec = torch.arange(n_pos, dtype=torch.float)
pe = creat_pe_absolute_sincos_embedding(n_pos_vec, dim)
print(pe)
tensor([[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  ...,  1.0000e+00,
          0.0000e+00,  1.0000e+00],
        [ 8.4147e-01,  5.4030e-01,  8.2843e-01,  ...,  1.0000e+00,
          1.0243e-04,  1.0000e+00],
        [ 9.0930e-01, -4.1615e-01,  9.2799e-01,  ...,  1.0000e+00,
          2.0486e-04,  1.0000e+00],
        ...,
        [ 6.1950e-02,  9.9808e-01,  5.3552e-01,  ...,  9.9857e-01,
          5.2112e-02,  9.9864e-01],
        [ 8.7333e-01,  4.8714e-01,  9.9957e-01,  ...,  9.9857e-01,
          5.2214e-02,  9.9864e-01],
        [ 8.8177e-01, -4.7168e-01,  5.8417e-01,  ...,  9.9856e-01,
          5.2317e-02,  9.9863e-01]])

3、按行对位置编码矩阵进行可视化:

# 不同pos
import matplotlib.pyplot as plt
x = [i for i in range(dim)]
for index, item in enumerate(pe):
  if index % 50 != 1:
    continue
  y = item.tolist()
  plt.plot(x, y, label=f"数据 {index}")
  plt.show()

以50为间隔打印,由于序列长度为512,所以可以打印出11个pos位置的曲线,下图为pos0,pos250,pos500处的位置编码曲线:
在这里插入图片描述

4、按列对位置编码矩阵进行可视化:

# 不同dim
x = [i for i in range(n_pos)]
for index, item in enumerate(pe.transpose(0, 1)):
  if index % 50 != 1:
    continue
  y = item.tolist()
  plt.plot(x, y, label=f"数据 {index}")
  plt.show()

以50为间隔打印,由于序列长度为768,所以可以打印出16个pos位置的曲线,下图为dim0,dim350,dim750处的位置编码曲线:
在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/Brilliant_liu/article/details/135033645