背景の紹介
トランスフォーマー アーキテクチャでは、単語ベクトルの入力に、元の単語に対応する位置情報をトレーニング用のモデルへの入力として追加する必要がありますが、特定の位置エンコーディングを実装するにはどうすればよいですか? このブログでは、対応する手順を共有します。
位置エンコード式
単語ベクトルの位置をエンコードするにはさまざまな方法がありますが、ここでは三角関数を使用した位置エンコードの公式を紹介します。
PEは位置埋め込み位置符号化を意味し、posは単語の位置と単語ベクトルの次元を表し、iは単語ベクトルのi番目の次元を表す。
次に、次の式に従って位置エンコーディングのコードを実装します。
コード
環境依存ライブラリ
import torch
import math
import numpy as np
import matplotlib.pyplot as plt
位置エンコーディング情報を取得する関数を定義する
def generate_word_embeding(max_len,d_model):
# 初始化位置信息
pos = torch.arange(max_len).unsqueeze(1)
# 初始化位置编码矩阵
result = torch.zeros(max_len,d_model)
# 获得公式对应的值
coding = torch.exp(torch.arange(0,d_model,2)*(-math.log(10000.0))/d_model)
result[:,0::2] = torch.sin(pos*coding)
result[:,1::2] = torch.cos(pos*coding)
# 为了与原编码直接相加,格式为[B,seq_len,d_model],需要再增加一个维度
return result.unsqueeze(0)
max_len が 100、d_model が 20 であると仮定すると、pos の次元は [100,1]、result の次元は [100,20]、コーディングの次元は [1,d_model/2]、result[: ,0::2] は、式の PE(pos,2i) に対応する、結果の列 0 から始まる 1 つおきの列に値を割り当てることを指します。同様に、result[:,1::2] は PE に対応します。式の (pos)、2i+1)
位置情報をエンコードした情報の可視化
位置エンコーディング情報を視覚化して、より直感的に感じられるようにします
d = 6
pos_code = generate_word_embeding(100,d)
print(pos_code.shape)
plt.plot(np.arange(100),pos_code[0,:,0:d])
plt.legend(['dim=%d'%p for p in range(d)])
plt.show()
単語の時間長を 6 に設定し、各次元の位置コーディング情報を対応する時間シーケンスで表示します。
各時系列位置は、各次元の三角関数の変換規則に対応していることがわかり、学習用モデルに入力した後、学習を通じて対応する位置の知識を得ることができます。
皆さんもぜひ議論や交流をしてください〜