Transformer基本コード実装(入力部分実装)

    入力部分は大きく分けて「Text Embedded Layer」(Embedings)と「Positional Encoding」の2つに分けられます。

1.
    テキスト埋め込み層のこの層の目的は、テキスト語彙のデジタル表現をベクトル表現に変換することです。

class Embeddings(nn.Module):
    def __init__(self,d_model,vocab):
        super(Embeddings,self).__init__()
        self.lut = nn.Embedding(vocab,d_model)
        self.d_model = d_model

    def forward(self,x):
        return self.lut(x)*math.sqrt(self.d_model)

    d_model は、単語の埋め込み次元、つまり、単語が表す必要がある次元ベクトルの数を表します。
    vocab は、語彙内の単語の総数を表します。
    組み込みの Embedding モジュールを直接使用してデータを処理し、最後に d_model を使用してスケーリングします。
    テストコード:

d_model = 512
vocab = 1000
x=Variable(torch.LongTensor([[100,2,29,165],[1,6,8,7]]))

emb=Embeddings(d_model,vocab)
embr=emb(x)

ここに画像の説明を挿入
2.位置エンコーダー
    のベクトル化されたテキストには位置情報がありませんが、位置エンコーダーを使用すると、ベクトルに位置情報を反映させることができます。

class PositionalEncoding(nn.Module):
    def __init__(self,d_model,dropout,max_len=5000):
        super(PositionalEncoding,self).__init__()

        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len,d_model)

        position = torch.arange(0,max_len).unsqueeze(1)

        div_term = torch.exp(torch.arange(0,d_model,2)*(-math.log(10000.00)/d_model))

        pe[:,0::2]=torch.sin(position*div_term)

        pe[:,1::2]=torch.cos(position*div_term)

        pe=pe.unsqueeze(0)

        self.register_buffer('pe',pe)

    def forward(self,x):
        x=x+Variable(self.pe[:,:x.size(1)],requires_grad = False)
        return self.dropout(x)

    位置情報は、まず位置エンコード行列と絶対位置行列を必要とします。
    位置エンコード行列は最初に 0 に設定され、そのサイズは (文の長さ、単語の次元) です。絶対位置行列のサイズは (0, 文長) で、その値は連続する自然数です. unsqueeze メソッドを使用して、行列の次元を広げて (文長, 1) 行列にします。
    絶対位置行列を (文の長さ、単語の次元) のサイズに変更したいので、変換行列 div_term が必要です。ジャンプは div_term の初期化時に使用されます。これは、奇数と偶数に従って初期化することで、sin と cos を使用して位置を区別できるようにします。
    この処理の後、2 次元の行列が得られますが、これを埋め込みで出力するには、行列の次元を増やし、最後にコードをモデル バッファーとして登録する必要があります。この行列にはパラメーターがなく、その後の更新が必要ないためです。 . 最初から 直せます。
    3 桁のテンソルの 2 番目の次元 (文の最大長の次元) をスライスして、入力 X の 2 番目の次元と同じにするために、データの操作も必要です。文章の長さなどにも合わせます。
ここに画像の説明を挿入

おすすめ

転載: blog.csdn.net/daweq/article/details/129803994