Transformer の Attendance (注意メカニズム) について少し調べてみる


要約: この記事では、Transformer のアテンション メカニズムについて少し調査してみます。さらに 6 つの問題について詳しく説明します。


✅ NLP リサーチ 1 出場者の学習ノート

はじめに: Xiao Wang、NPU、2023 年生、コンピュータ技術
研究の方向性: テキスト生成、要約生成



1. なぜこのブログを書くのか?

冗談:トランスフォーマーモデルの魔改造を最近やっていたのですが、しばらくやっていなかったのでアテンションの内容を覚えていないのですが、今日は復習してみましょう。6 つの質問を使用して、注意メカニズムの詳細な分析に焦点を当てます。

● 次のコードは T5-base の Attendance のソース コードです。これについては後で詳しく説明し、分析します。多くの変数は実際には使用されず、長すぎるため、T5Attention コードの短縮バージョンを後ろに掲載しました (実際の効果には影響しません)。

class T5Attention(nn.Module):
    def __init__(self, config: T5Config, has_relative_attention_bias=False):
        super().__init__()
        self.is_decoder = config.is_decoder
        self.has_relative_attention_bias = has_relative_attention_bias
        self.relative_attention_num_buckets = config.relative_attention_num_buckets
        self.relative_attention_max_distance = config.relative_attention_max_distance
        self.d_model = config.d_model
        self.key_value_proj_dim = config.d_kv
        self.n_heads = config.num_heads
        self.dropout = config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim

        # Mesh TensorFlow initialization to avoid scaling before softmax
        self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)

        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
        self.pruned_heads = set()
        self.gradient_checkpointing = False

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
        )
        # Prune linear layers
        self.q = prune_linear_layer(self.q, index)
        self.k = prune_linear_layer(self.k, index)
        self.v = prune_linear_layer(self.v, index)
        self.o = prune_linear_layer(self.o, index, dim=1)
        # Update hyper params
        self.n_heads = self.n_heads - len(heads)
        self.inner_dim = self.key_value_proj_dim * self.n_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets

    def compute_bias(self, query_length, key_length, device=None):
        """Compute binned relative position bias"""
        if device is None:
            device = self.relative_attention_bias.weight.device
        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=(not self.is_decoder),
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values

    def forward(
        self,
        hidden_states,
        mask=None,
        key_value_states=None,
        position_bias=None,
        past_key_value=None,
        layer_head_mask=None,
        query_length=None,
        use_cache=False,
        output_attentions=False,
    ):
        """
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
        """
        # Input is (batch_size, seq_length, dim)
        # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
        # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
        batch_size, seq_length = hidden_states.shape[:2]

        real_seq_length = seq_length

        if past_key_value is not None:
            assert (
                len(past_key_value) == 2
            ), f"past_key_value should have 2 past states: keys and values. Got {
      
       len(past_key_value)} past states"
            real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length

        key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]

        def shape(states):
            """projection"""
            return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

        def unshape(states):
            """reshape"""
            return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)

        def project(hidden_states, proj_layer, key_value_states, past_key_value):
            """projects hidden states correctly to key/query states"""
            if key_value_states is None:
                # self-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(hidden_states))
            elif past_key_value is None:
                # cross-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(key_value_states))

            if past_key_value is not None:
                if key_value_states is None:
                    # self-attn
                    # (batch_size, n_heads, key_length, dim_per_head)
                    hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
                else:
                    # cross-attn
                    hidden_states = past_key_value
            return hidden_states

        # get query states
        query_states = shape(self.q(hidden_states))  # (batch_size, n_heads, seq_length, dim_per_head)

        # get key/value states
        key_states = project(
            hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
        )
        value_states = project(
            hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
        )

        # compute scores
        scores = torch.matmul(
            query_states, key_states.transpose(3, 2)
        )  # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9

        if position_bias is None:
            if not self.has_relative_attention_bias:
                position_bias = torch.zeros(
                    (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
                )
                if self.gradient_checkpointing and self.training:
                    position_bias.requires_grad = True
            else:
                position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)

            # if key and values are already calculated
            # we want only the last query position bias
            if past_key_value is not None:
                position_bias = position_bias[:, :, -hidden_states.size(1) :, :]

            if mask is not None:
                mask = mask.to('cuda')
                position_bias = position_bias + mask  # (batch_size, n_heads, seq_length, key_length)

        if self.pruned_heads:
            mask = torch.ones(position_bias.shape[1])
            mask[list(self.pruned_heads)] = 0
            position_bias_masked = position_bias[:, mask.bool()]
        else:
            position_bias_masked = position_bias

        scores += position_bias_masked
        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
            scores
        )  # (batch_size, n_heads, seq_length, key_length)
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.dropout, training=self.training
        )  # (batch_size, n_heads, seq_length, key_length)

        # Mask heads if we want to
        if layer_head_mask is not None:
            attn_weights = attn_weights * layer_head_mask

        attn_output = unshape(torch.matmul(attn_weights, value_states))  # (batch_size, seq_length, dim)
        attn_output = self.o(attn_output)

        present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
        outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)

        if output_attentions:
            outputs = outputs + (attn_weights,)
        return outputs

● T5Attention コードの短縮版 (以下の説明には影響しません)。以下で説明する「コード」はこれを指します。

class T5Attention(nn.Module):
    def __init__(self, config: T5Config):
        super().__init__()
        self.is_decoder = config.is_decoder
        self.d_model = config.d_model
        self.key_value_proj_dim = config.d_kv
        self.n_heads = config.num_heads
        self.dropout = config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim

        # Mesh TensorFlow initialization to avoid scaling before softmax
        self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)

    def forward(
        self,
        hidden_states,
        key_value_states=None,
        position_bias=None,
        past_key_value=None,
        layer_head_mask=None,
        query_length=None,
        use_cache=False,
        output_attentions=False,
    ):
        def shape(states):
            return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

        def unshape(states):
            return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
           
        # Input is (batch_size, seq_length, dim)
        batch_size, seq_length = hidden_states.shape[:2]

        # get query states
        query_states = shape(self.q(hidden_states))  # (batch_size, n_heads, seq_length, dim_per_head)

        # get key/value states
        key_states = shape(self.k(hidden_states))  # (batch_size, n_heads, seq_length, dim_per_head)
        value_states = shape(self.v(hidden_states))  # (batch_size, n_heads, seq_length, dim_per_head)

        # compute scores
        scores = torch.matmul(
            query_states, key_states.transpose(3, 2)
        )  # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9

        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
            scores
        )  # (batch_size, n_heads, seq_length, key_length)
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.dropout, training=self.training
        )  # (batch_size, n_heads, seq_length, key_length)

        attn_output = unshape(torch.matmul(attn_weights, value_states))  # (batch_size, seq_length, dim)
        attn_output = self.o(attn_output)

        return attn_output 

2. 魂の質問に答えていただけますか?

  1. 注意の入力は何ですか? アウトプットとは何ですか?
  2. アテンションの入力と「Q、K、V」の関係は何ですか?
  3. 注意の Q、K、V は何を意味しますか?
  4. アテンションの計算プロセスは何ですか?
  5. 長い注意力は何の役に立つのでしょうか?
  6. クロスアテンションを導入してください。

注: 以下のモデルはすべてconfigT5 モデルのデフォルト構成であり、説明の便宜のために書かれています。


1. アテンションの入力は何ですか? 出力は何ですか?

  回答: 注意 の入力は text features でhidden_states、これは正式にはバッチ ( batch_size) 特徴テンソルです。バッチ (batch_size) が4、最大テキスト処理長 ( max_length) が512、フィーチャの寸法 ( d_model) が の場合768hidden_states入力の形状は です(4, 512, 768)出力も、出力と同じ形状のテンソルです。


2. アテンションの入力と「Q、K、V」の関係は何ですか?

  回答: まずは「入力と3つの関係」について説明します。上記のコードを次のセクションに切り出します。self.qself.kself.vそれぞれ3 つの異なるものnn.Linear(self.d_model, self.inner_dim, bias=False)、つまり線形層であるため、その中にself.d_model一般的な特徴次元768(つまり、各アテンション モジュールによって送信される際に統一する必要がある特徴次元) があり、 はself.inner_dimアテンションの内部特徴次元です768(つまり、アテンション モジュールで内部的に使用されるフィーチャ ディメンション)。したがって、nn.Linear(self.d_model, self.inner_dim, bias=False)バイアスを含めて 768×768 のニューラル ネットワークになりますnn.Linear(768, 768, bias=False)
  したがって、入力と「Q、K、V」の間にはマッピング関係があります。入力 (入力) は、3 つの異なる 768 × 768 ニューラル ネットワークを通じて 3 つの異なる特徴テンソルにマッピングされ、これら 3 つの異なる特徴テンソルの形状は依然として(4, 512, 768)(ここでは最初の質問を引き継ぎます) 「If」の場合です。次のいくつかの質問)。

# get query states
query_states = shape(self.q(hidden_states))  # (batch_size, n_heads, seq_length, dim_per_head)

# get key/value states
key_states = shape(self.k(hidden_states))  # (batch_size, n_heads, seq_length, dim_per_head)
value_states = shape(self.v(hidden_states))  # (batch_size, n_heads, seq_length, dim_per_head)

  ここでは、それが何に使用されるかを簡単に説明しますshape()テンソルを、中身は変えずに形状だけを「多頭テンソル」に変換する「形状コンバータ」です。平たく言えば、テンソルを 1 次元だけ拡張し、テンソル内のすべての項目をこの次元に割り当てることです。

def shape(states):
	return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

  例: 入力が(4, 512, 768)形状のテンソルである場合、つまりhidden_states= Tensor{(4, 512, 768)} となります。Q== Tensor{(4, 512, 768)}の場合self.q(hidden_states)(注self.q(hidden_states)hidden_states)。するとquery_states= shape(Q)= Q.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)= Q.view(4, -1, 12, 64).transpose(1, 2)= Tensor{(4, 512, 12, 64)}. transpose(1, 2)= Tensor{(4, 12, 512, 64)}
  このうち、self.n_headsマルチヘッド アテンション メカニズムのヘッド数を表し、T5モデルのデフォルトは です12self.key_value_proj_dim各ヘッドのフィーチャー寸法を示します。これは、T5 モデルのデフォルトです64


3. 「注意」の Q、K、V の意味は何ですか?

  回答: 説明の便宜上、Li Honyi のコースのセルフアテンションの例から始めましょう。最初に計算プロセスを見てみましょう: 入力が 4 つの単語のみからなる文の場合: 「Hello World」、単語を実行し
  ますセグメンテーション (トークナイザー) は ["you"、"好"、"世"、"世界"] を取得します。(説明の便宜上、後で Embedding 操作を省略し、直接 Attending を実行します。) そして、q 1 q^1q1 =self.q(「あなた」)k 1 k^1k1 =self.k(「あなた」)v 1 v^1v1 =self.v(「あなた」);q 2 q^2q2 =self.k(「良い」)k 2 k^2k2 =self.k("良い")v 2 v^2v2 =self.v("good"); など...
  その後、α 1, 1 \alpha_{1,1}ある1、1 _ _= q 1 × k 1 q^1\times k^1q1×k1 ;α 1 , 2 \alpha_{1,2}ある1、2 _ _= q 1 × k 2 q^1\times k^2q1×k2 ; など... (α x, 任意 \alpha_{x, 任意}あるx オプション全てNo.xxx語が処理されます。また、従来の Attendance もd \sqrt{d}d 説明の便宜上、ここでは動作を省略しています次の式が追加されます)
  、その後α 1 , 1 \alpha_{1,1}ある1、1 _ _α 1 , 2 \alpha_{1,2}ある1、2 _ _α 1 , 3 \alpha_{1,3}ある1、3 _ _およびα 1 , 4 \alpha_{1,4}ある1、4 _ _正規化、つまりnn.functional.softmax()コード内の演算を実行し、α ^ 1 , 1 \hat \alpha_{1,1}を取得します。ある^1、1 _ _= α 1 , 1 ∑ in = 4 α 1 , i \frac{\alpha_{1,1}}{\sum_{i}^{n=4}\alpha_{1,i}}n = 4ある1 ある1、1 _ _; α ^ 1 , 2 \hat \alpha_{1,2}ある^1、2 _ _= α 1 , 2 ∑ in = 4 α 1 , i \frac{\alpha_{1,2}}{\sum_{i}^{n=4}\alpha_{1,i}}n = 4ある1 ある1、2 _ _; など... b 1 b^1
  を計算します。b1 =∑ jn = 4 α ^ 1 , j × vj \sum_{j}^{n=4}\hat \alpha_{1,j}\times v^jjn = 4ある^1 j×vj =α ^ 1 , 1 × v 1 + α ^ 1 , 2 × v 2 + α ^ 1 , 3 × v 3 + α ^ 1 , 4 × v 4 \hat \alpha_{1,1} \times v^ 1 + \hat \alpha_{1,2} \times v^2 + \hat \alpha_{1,3} \times v^3+ \hat \alpha_{1,4} \times v^4ある^1、1 _ _×v1+ある^1、2 _ _×v2+ある^1、3 _ _×v3+ある^1、4 _ _×v4
ここに画像の説明を挿入します
  上の計算はb 1 b^1b1回の計算。次にb 2 b^2を実行する必要があります。b2b3b^3b3b 4 b^4b4の計算。計算プロセスも同様ですが、次にb 1 b^1b1が計算されます: (上の図ではα ^ \hat \alpha注意してくださいある下図の^とα '' \alpha'ある'はすべて同じ意味を持つことを意味します下の写真には1 a_1ある1a2a_2ある2a3a_3ある3 a 4 a_4 ある4は、それぞれ「you」、「hao」、「shi」、「jie」という 4 つの単語として理解できます)
  α 2 , 1 \alpha_{2,1}ある2、1 _ _= q 2 × k 1 q^2\times k^1q2×k1 ;α 2 , 2 \alpha_{2,2}ある2、2 _ _= q 2 × k 2 q^2\times k^2q2×k2 ; など...
  α ^ 2 , 1 \hat \alpha_{2,1}ある^2、1 _ _= α 2 , 1 ∑ in = 4 α 2 , i \frac{\alpha_{2,1}}{\sum_{i}^{n=4}\alpha_{2,i}}n = 4ある2 ある2、1 _ _; α ^ 2 , 2 \hat \alpha_{2,2}ある^2、2 _ _= α 2 , 2 ∑ in = 4 α 2 , i \frac{\alpha_{2,2}}{\sum_{i}^{n=4}\alpha_{2,i}}n = 4ある2 ある2、2 _ _; など...
  b 2 b^2b2 =∑ jn = 4 α ^ 2 , j × vj \sum_{j}^{n=4}\hat \alpha_{2,j}\times v^jjn = 4ある^2 j×vj =α ^ 2 , 1 × v 1 + α ^ 2 , 2 × v 2 + α ^ 2 , 3 × v 3 + α ^ 2 , 4 × v 4 \hat \alpha_{2,1} \times v^ 1 + \hat \alpha_{2,2} \times v^2 + \hat \alpha_{2,3} \times v^3+ \hat \alpha_{2,4} \times v^4ある^2、1 _ _×v1+ある^2、2 _ _×v2+ある^2、3 _ _×v3+ある^2、4 _ _×v4最後に
ここに画像の説明を挿入します
  、いくつかの計算の後、b 1 b^{1}b1b 2 b^2b2b3b^3b3b 4 b^4b4はそれぞれ「you」、「good」、「world」、「world」の出力です。
  一見するとそう見えますが、各単語に対して 3 つの異なるニューラル ネットワーク マッピングを実行し、各単語の 3 つの異なるコピー「Q、K、V」を取得し、その「コピー Q」を追加しているだけではないでしょうか。各単語?それ自体と他の単語の「コピー K」を乗算し、その結果を正規化して「注意分布」を取得し、この「注意分布」とそれ自体と他の単語の「コピー V」を掛けます。つまり、各単語の「注目スコア」が得られます。
  さて、計算が終わったら、Q、K、Vの意味を深く分析してみましょう!
  ["you", "好", "世", "世界"] の "世" という単語の "注目スコア" を計算すると、次のように考えてみましょう。すると、Qは読んだ単語の「基本的な意味」を表し、「世界、世代、世紀、誕生、死、人間世界」などが自然に思い浮かび、Kは文全体(または記事全体)を表します。 ) 単語のより好ましい意味 -> 「世界、人間の世界」など (各単語の K は他の単語の Q と合わせて計算されるため、勾配が更新されると、K のネットワークは更新すると、「つながり」が少し深まります) 最後に、V をより抽象的に理解する必要があります。特定の人の心として理解できます (人によって経験が異なり、この単語を見たときに感じることも異なります)この言葉の「感情」は「当たり障りのない」かもしれないし、「暖かい」「愛しい」「気持ち悪い」などかもしれないし、とにかく薛志謙の「世界平和」という感情もあるのかもしれない。拡張された意味を表します。(この説明はあくまで筆者の個人的な理解です。後ほどクロスアテンションの仕組みで別の視点から理解していきます)
  実験は行っていませんが、あるテキスト データセットを学習する場合、「Q、K、V」の 3 つのネットワークはネットワークに似ているため、最も遅い「V のネットワーク」を更新するはずだと推測しています。バリューネットワーク」。また、「Q's network」は、同じ単語が異なる文章で異なる意味をもつことが非常に多いため、最も速く更新されます。
  最後に、ここで伏線を張りますが、「注目度分布」とは、Qの各単語とKの各単語の「相関スコア」のようなもので、相関が強いほどスコアが高くなります。これについては、後の計算プロセスでもう一度説明します。

  Q、K、V の意味を理解した上で、簡単な例を挙げてみましょう。T5 モデルに次の文を入力した場合 (ChatGPT にも同様に入力します)、モデルの最終出力はどうなるでしょうか? どうやって推理したんですか?

リンゴを買ったところですが、とてもおいしいと思いました。私は今何を食べましたか?

  明らかに、T5 モデルは「apple」と「it」という単語により多くの注意を払うことになり、「q と k」を掛けた「注意分布」は非常に大きくなります。 」は「リンゴ」とのつながりが最も強いので、答えは「リンゴ」です。


4. アテンションの計算プロセスは何ですか?

  回答:注意力の計算過程については「3.注意力のQ、K、Vの意味は?」で紹介しましたが、それはほんの一部です。完全な計算プロセスはコードにも依存します (すべての重要なステップを含む超簡略化バージョン)。

def forward(self, hidden_states, key_value_states=None, position_bias=None, past_key_value=None, layer_head_mask=None, query_length=None, use_cache=False, output_attentions=False,):

	def shape(states):  # 作用: 将某一个张量的形状从 (batch_size, seq_length, d_model) 转化为 (batch_size, n_heads, seq_length, dim_per_head)
		return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
		
	def unshape(states):  # 作用: 将某一个张量的形状从 (batch_size, n_heads, seq_length, dim_per_head) 转化为 (batch_size, seq_length, d_model)
		return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
		
	batch_size, seq_length = hidden_states.shape[:2]  # hidden_states 是一个张量 {Tensor(4, 512, 768)}
	query_states = shape(self.q(hidden_states))  # 计算出 Q, 并转化为多头, 得到 {Tensor(4, 12, 512, 64)}
    key_states = shape(self.k(hidden_states))  # 计算出 K, 并转化为多头, 得到 {Tensor(4, 12, 512, 64)}
    value_states = shape(self.v(hidden_states))  # 计算出 V, 并转化为多头, 得到 {Tensor(4, 12, 512, 64)}
    scores = torch.matmul(query_states, key_states.transpose(3, 2))  # 计算 “注意力分布” {Tensor()}, {Tensor(4, 12, 512, 64)} × {Tensor(4, 12, 64, 512)} → {Tensor(4, 12, 512, 512)}
    attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)  # softmax操作(即归一化操作)
    attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)  # dropout操作(防止过拟合)
    attn_output = unshape(torch.matmul(attn_weights, value_states))  
    # 先进行矩阵相乘 {Tensor(4, 12, 512, 512)} × {Tensor(4, 12, 512, 64)} → {Tensor(4, 12, 512, 64)}, 再将多头特征融合起来. {Tensor(4, 12, 512, 64)} → {Tensor(4, 512, 768)}
    attn_output = self.o(attn_output)  # 最后经过一个线性层, 得到的结果仍然是一个张量 {Tensor(4, 512, 768)}
    
    return attn_output 

  知らせ!説明をわかりやすくするため、Attentionの仕組みに焦点を当てるため、上記のコードではマスク(Mask)の操作相対位置情報の埋め込み(position_bias)を省略しています。 。
  簡単に言うと、attention の計算はすべて行列 (テンソルに基づいているとも言えます) に基づいており、前に書いたように 1 つずつ計算されるわけではありません: α 1 , 1 \alpha_ { 1,1 }ある1、1 _ _= q 1 × k 1 q^1\times k^1q1×k1α 1 , 2 \alpha_{1,2}ある1、2 _ _= q 1 × k 2 q^1\times k^2q1×k2ですが、行列計算の方が便利で計算を高速化できる点を除けば、計算結果は同じです (ここで、d \sqrt{d}d 一般に、これはモデルのデフォルトのフィーチャー次元d_model、つまり768) です。

a = q × k ⊺ d = [ q 1 q 2 q 3 q 4 ] × [ k 1 k 2 k 3 k 4 ] d = [ a 1 , 1 a 1 , 2 ⋯ a 1 , 4 a 2 , 1 a 2 , 2 ⋯ a 2 , 4 ⋮ ⋮ ⋱ ⋮ a 4 , 1 a 4 , 2 ⋯ a 4 , 4 ] \boldsymbol{a}=\frac{\boldsymbol{q}\times \boldsymbol{k}^{\ intercal}}{\sqrt{d}}=\frac{\\ \left[ \begin{行列} q_1 \\ q_2 \\ q_3 \\ q_4\end{行列}\right] \times \left[ \begin{行列} k_1\,\, k_2 \,\,k_3 \,\,k_4\end{行列}\right] \\}{\sqrt{d}} = \left[ \begin{行列} a_{1,1 }& a_{1,2}& \cdots& a_{1,4}\\ a_{2,1}& a_{2,2}& \cdots& a_{2,4}\\ \vdots& \vdots& \ddots& \ vdots\\ a_{4,1}& a_{4,2}& \cdots& a_{4,4}\\ \end{行列} \right]ある=d q×k=d q1q2q3q4 ×[k1k2k3k4= ある1、1 _ _ある2、1 _ _ある4、1 _ _ある1、2 _ _ある2、2 _ _ある4、2 _ _ある1、4 _ _ある2、4 _ _ある4、4 _ _

α = ソフトマックス ( a ) = [ a 1 , 1 ' a 1 , 2 ' ⋯ a 1 , 4 ' a 2 , 1 ' a 2 , 2 ' ⋯ a 2 , 4 ' ⋮ ⋮ ⋱ ⋮ a 4 , 1 ' a 4 , 2 ′ ⋯ a 4 , 4 ′ ] , ai , j ′ = exp ⁡ ( ai , j ) ∑ j = 1 n = 4 exp ⁡ ( ai , j ) \boldsymbol{\alpha }=\text{softmax} \left( \boldsymbol{a} \right) =\left[ \begin{行列} a'_{1,1}& a'_{1,2}& \cdots& a'_{1,4}\\ a'_{2,1}& a'_{2,2}& \cdots& a'_{2,4}\\ \vdots& \vdots& \ddots& \vdots\\ a'_{4,1}& a '_{4,2}& \cdots& a'_{4,4}\\ \end{matrix} \right] ,\quad a'_{i,j}=\frac{\exp \left( a_{ i,j} \right)}{\sum_{j=1}^{n=4}{\exp}\left( a_{i,j} \right)}ある=ソフトマックス( a )= ある1、1 _ _ある2、1 _ _ある4、1 _ _ある1、2 _ _ある2、2 _ _ある4、2 _ _ある1、4 _ _ある2、4 _ _ある4、4 _ _ あるj=j = 1n = 4経験値( _ j)経験値( _ j)
b = α × v = [ a 1 , 1 ' a 1 , 2 ' ⋯ a 1 , 4 ' a 2 , 1 ' a 2 , 2 ' ⋯ a 2 , 4 ' ⋮ ⋮ ⋱ ⋮ a 4 , 1 ' a 4 , 2 ′ ⋯ a 4 , 4 ′ ] × [ v 1 v 2 v 3 v 4 ] = [ b 1 b 2 b 3 b 4 ] \boldsymbol{b}=\boldsymbol{\alpha }\times \boldsymbol{v } = \left[ \begin{行列} a'_{1,1}& a'_{1,2}& \cdots& a'_{1,4}\\ a'_{2,1}& a '_{2,2}& \cdots& a'_{2,4}\\ \vdots& \vdots& \ddots& \vdots\\ a'_{4,1}& a'_{4,2}& \cdots& a'_{4,4}\\ \end{行列} \right] \times \left[ \begin{行列} v_1 \\ v_2 \\ v_3 \\ v_4\end{行列}\right] = \left[ \begin{行列} b_1 \\ b_2 \\ b_3\\b_4\end{行列}\right]b=ある×v= ある1、1 _ _ある2、1 _ _ある4、1 _ _ある1、2 _ _ある2、2 _ _ある4、2 _ _ある1、4 _ _ある2、4 _ _ある4、4 _ _ × v1v2v3v4 = b1b2b3b4
コードでは、上記のb \boldsymbol{b}   が計算されますbdropoutの後には、実際には過剰適合を防ぐための操作がありますさらに、最終的な出力\text{output} は出力=self.o( b \boldsymbol{b}b )、ここでself.o()=nn.Linear(768, 768, bias=False)なぜこの線形レイヤーを追加する必要があるのでしょうか? おそらくこれはニューラル ネットワークの形而上学であり、層が増えると記憶はより深くなる可能性があるのでしょうか? ...
  OK、ここで先ほど敷いた伏線を明らかにします。行列α \boldsymbol{\alpha }αの形状から4×4、つまりQバージョン[「あなた」、「善」、「世界」、「傑」]とKバージョン[「あなた」であることがわかります。 ", " 「好」「世」「世界」の各単語間の「相関スコア」の一種。Q バージョンの「世」という単語 ["you"、"好"、"世"、"界"] が K バージョンの "世界" という単語 ["you"、"好"、"世"] の反対である場合", "世界"] 単語はモデル内で非常に関連しており、次に3, 4 ' a'_{3,4}ある3、4 _ _他の「注目度分布」と比べてスコアが大きくなります。


5. 長い注意力は何に役立ちますか?

  回答: まず、マルチヘッドの計算プロセスについて説明します。 注意: つまり、次元の特徴 (つまり [ 、 、 、 、 、7681.2563] -5.29340.0567何気なく-0.8004書き3.0503ました、特徴を表す数字の0.2502合計) を分割しますサブ特徴768に分割すると( =の場合、サブ特徴の次元は= =になります)、特徴番号を含むサブ特徴が得られます。   この原理は、コンピュータビジョン(CV)の分野における畳み込みニューラルカーネルの原理に似ていて、特徴量の計算をより「繊細」に、あるいは「より荒く」するためのものな気がします。3×3、5×5、1×1などの異なるコンボリューションカーネルサイズには、コンボリューションによって得られる特徴が異なります。明らかに、コンボリューションカーネルサイズが大きいほど、コンボリューションによって得られる特徴はより「正確」になります。カーネル サイズが小さいほど、畳み込みによって得られる特徴はより「一般的」になります。次に、雄牛が多ければ多いほど、テキストが「注意分布」計算を実行するときに、特定の単語をより多くの意味の層 (レイヤーの意味など)に分割し、同様に複数意味の層を持つ別の単語と組み合わせることができます。 .単語が掛け合わされることで、得られる複数の特徴がより「包括的」になります。NN12head_d_model768/1264N768/N
12


6. クロスアテンションを導入してください。

  回答: クロスアテンションはクロスアテンションのメカニズムです。前述のすべての例は、実際にはセルフ アテンション メカニズム、つまり、特定のテキストとそれ自体の間の「セルフ アテンション」の計算であり、あるテキストと別のテキストの間の「相互注意」の計算には関与しません。
  まず、クロスアテンションの計算プロセスを明確にしましょう
  次の文①を使用して文②クロスアテンションを実行するとします(注!!!順序は非常に重要です。最初に来た人が前者で、最後に来た人が後者です) 、誰が正しいか、順序が重要です)。

① 李華さんは英語の試験に落ちて、クラスメートと遊びに行ったと母親に話しました。
② シャオミンはスーパーの入り口までリーファに会いに行き、たくさん話をしました。

  次に、q \boldsymbol qqself.q()は、k \boldsymbol kによって得られた文①です。kv \boldsymbol vvself.k()は、それぞれとself.v()得られた文②ですこのうち、文①の長さは23、文②の長さは です21すると、
a = q × k ⊺ d = [ q 1 q 2 ⋮ q 23 ] × [ k 1 k 2 ⋯ k 21 ] d = [ a 1 , 1 a 1 , 2 ⋯ a 1 , 21 a 2 , 1 a 2 , 2 ⋯ a 2 , 21 ⋮ ⋮ ⋱ ⋮ a 23 , 1 a 23 , 2 ⋯ a 23 , 21 ] ∈ R 23 × 21 \boldsymbol{a}=\frac{\boldsymbol{q}\times \boldsymbol{ k}^{\intercal}}{\sqrt{d}}=\frac{\\ \left[ \begin{行列} q_1 \\ q_2 \\ \vdots \\ q_{23}\end{行列}\right ] \times \left[ \begin{行列} k_1\,\, k_2 \,\, \cdots \,\,k_{21}\end{行列}\right] \\}{\sqrt{d}} = \left[ \begin{行列} a_{1,1}& a_{1,2}& \cdots& a_{1,21}\\ a_{2,1}& a_{2,2}& \cdots& a_{ 2,21}\\ \vdots& \vdots& \ddots& \vdots\\ a_{23,1}& a_{23,2}& \cdots& a_{23,21}\\ \end{matrix} \right] \in \mathbb{R}^{23 \times 21}ある=d q×k=d q1q2q23 ×[k1k2k21= ある1、1 _ _ある2、1 _ _ある23、1 _ _ある1、2 _ _ある2、2 _ _ある23、2 _ _ある1、21 _ _ある2、21 _ _ある23、21 _ _ R23 × 21

α = ソフトマックス ( a ) = [ a 1 , 1 ' a 1 , 2 ' ⋯ a 1 , 21 ' a 2 , 1 ' a 2 , 2 ' ⋯ a 2 , 21 ' ⋮ ⋮ ⋱ ⋮ a 23 , 1 ' a 23 , 2 ′ ⋯ a 23 , 21 ′ ] , ai , j ′ = exp ⁡ ( ai , j ) ∑ j = 1 n = 21 exp ⁡ ( ai , j ) \boldsymbol{\alpha }=\text{softmax} \left( \boldsymbol{a} \right) =\left[ \begin{行列} a'_{1,1}& a'_{1,2}& \cdots& a'_{1,21}\\ a'_{2,1}& a'_{2,2}& \cdots& a'_{2,21}\\ \vdots& \vdots& \ddots& \vdots\\ a'_{23,1}& a '_{23,2}& \cdots& a'_{23,21}\\ \end{行列} \right] ,\quad a'_{i,j}=\frac{\exp \left( a_{ i,j} \right)}{\sum_{j=1}^{n=21}{\exp}\left( a_{i,j} \right)}ある=ソフトマックス( a )= ある1、1 _ _ある2、1 _ _ある23、1 _ _ある1、2 _ _ある2、2 _ _ある23、2 _ _ある1、21 _ _ある2、21 _ _ある23、21 _ _ あるj=j = 1n = 21経験値( _ j)経験値( _ j)
b = α × v = [ a 1 , 1 ' a 1 , 2 ' ⋯ a 1 , 21 ' a 2 , 1 ' a 2 , 2 ' ⋯ a 2 , 21 ' ⋮ ⋮ ⋱ ⋮ a 23 , 1 ' a 23 , 2 ′ ⋯ a 23 , 21 ′ ] × [ v 1 v 2 ⋮ v 21 ] = [ b 1 b 2 ⋮ b 23 ] ∈ R 23 × 1 \boldsymbol{b}=\boldsymbol{\alpha }\times \ボールドシンボル{v} = \left[ \begin{行列} a'_{1,1}& a'_{1,2}& \cdots& a'_{1,21}\\ a'_{2,1 }& a'_{2,2}& \cdots& a'_{2,21}\\ \vdots& \vdots& \ddots& \vdots\\ a'_{23,1}& a'_{23,2} & \cdots& a'_{23,21}\\ \end{行列} \right] \times \left[ \begin{行列} v_1 \\ v_2 \\ \vdots\\ v_{21}\end{行列} \right] = \left[ \begin{行列} b_1 \\ b_2 \\ \vdots \\b_{23}\end{行列}\right] \in \mathbb{R}^{23 \times 1}b=ある×v= ある1、1 _ _ある2、1 _ _ある23、1 _ _ある1、2 _ _ある2、2 _ _ある23、2 _ _ある1、21 _ _ある2、21 _ _ある23、21 _ _ × v1v2v21 = b1b2b23 R23 × 1

  最後に、dropout最後の線形層を通過した後に演算が実行され、出力が取得されますself.o()(注意!!!画像ではR 23 × 1 \mathbb{R}^{23 \times 1} と書きましたがR23 × 1ですが、より厳密な書き方はR 23 × 768 \mathbb{R}^{23 \times 768}Rb 1 b_1なので23 × 768b1b2b_2b2b3b_3b3そしてb 4 b_4b4実際、それらはすべて768次元ベクトルです。batch_size(= 4)を加えるとR 4 × 23 × 768 \mathbb{R}^{4\times 23 \times 768}となります。R4 × 23 × 768

  OK、計算プロセスは終了しました。では、クロスアテンション メカニズムの本質は何でしょうか? ちょうど「Q、K、V」の割り当てで。文①が文②に対してクロスアテンションを実行する場合(誰が誰に対して、前者と後者の関係は非常に重要です)、Q は文self.q()①から文のマッピング、K は文②から文のマッピング、V はself.k()文②から までself.v()のマッピング。Qはなぜ①の文からしか出てこないのでしょうか?文①を文②に対して分析するために使用するので、文①はクエリに似ており、文②は「辞書」に似ています。この辞書には多くの「キー」「値のペア」が含まれています。この「辞書」に目を通し、その「辞書」にある「鍵」を一つ一つ比較し、その「鍵」について思索する、一種の「探究」ともいえる問いを持ち込んでいます。 「価値」 - ここでの価値とは、辞書内の特定の単語の詳細な説明として理解できます。(この説明はあくまで筆者の個人的な理解であり、これまでの(Self-)Attendのメカニズムの理解とは少し異なります。)また
  、文①に対してクロスアテンションを行うために文②を使用したい場合、も可能です。ただ、途中で得られる「注目度分布行列」の大きさが変化してR 21 × 23 \mathbb{R}^{21 \times 23}となるだけです。R21 × 23、最終的にb \boldsymbol{b}bのサイズもR 21 × 1 \mathbb{R}^{21 \times 1}になりますR21 × 1

  最後に、クロスアテンション コードを見てみましょう。デモ用に T5 モデルのコードを使用してみましょう (簡易バージョン)。

def forward(self, hidden_states, key_value_states=None, position_bias=None, past_key_value=None, layer_head_mask=None, query_length=None, use_cache=False, output_attentions=False,):

	def shape(states):  # 作用: 将某一个张量的形状从 (batch_size, seq_length, d_model) 转化为 (batch_size, n_heads, seq_length, dim_per_head)
		return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
		
	def unshape(states):  # 作用: 将某一个张量的形状从 (batch_size, n_heads, seq_length, dim_per_head) 转化为 (batch_size, seq_length, d_model)
		return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
		
	batch_size, seq_length = hidden_states.shape[:2]  # hidden_states 是一个张量 {Tensor(4, 512, 768)}
	query_states = shape(self.q(hidden_states))  # 计算出 Q, 并转化为多头, 得到 {Tensor(4, 12, 512, 64)}
    key_states = shape(self.k(key_value_states))  # 计算出 K, 并转化为多头, 得到 {Tensor(4, 12, 512, 64)}
    value_states = shape(self.v(key_value_states))  # 计算出 V, 并转化为多头, 得到 {Tensor(4, 12, 512, 64)}
    scores = torch.matmul(query_states, key_states.transpose(3, 2))  # 计算 “注意力分布” {Tensor()}, {Tensor(4, 12, 512, 64)} × {Tensor(4, 12, 64, 512)} → {Tensor(4, 12, 512, 512)}
    attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)  # softmax操作(即归一化操作)
    attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)  # dropout操作(防止过拟合)
    attn_output = unshape(torch.matmul(attn_weights, value_states))  
    # 先进行矩阵相乘 {Tensor(4, 12, 512, 512)} × {Tensor(4, 12, 512, 64)} → {Tensor(4, 12, 512, 64)}, 再将多头特征融合起来. {Tensor(4, 12, 512, 64)} → {Tensor(4, 512, 768)}
    attn_output = self.o(attn_output)  # 最后经过一个线性层, 得到的结果仍然是一个张量 {Tensor(4, 512, 768)}
    
    return attn_output 

  これは、前の (Self-)tention メカニズムのコードに似ていますが、 と だけがkey_states異なりvalue_statesます。


3. 補足事項

●文章に誤り・不適切な点やご質問がございましたら、お気軽にコメント・シェアしてください。

●先ほど落とし穴を買ってしまった、など。マスク演算は省略されd \sqrt dd 相対位置情報の埋め込み(position_bias)の説明も後ほどブログを書くときに追記しますので、今日はここまでです…

冗談: なんだか…大学院に入ったばかりでめんどくさくて、断続的にこのブログを書き終えるのに3日もかかってしまいました(笑)[/Manual Dog Head]…。


⭐️⭐️⭐️

おすすめ

転載: blog.csdn.net/Wang_Dou_Dou_/article/details/132739888