CogViewにおけるネットワーク構造の全体構築

私は駆け出しの新人ですが、学んだことをメモのように記録していきたいと思っていますし、同じく始めようとしている人の手助けになればと思っています。

目次

1. 構造図

2、コード分析

1、__初期化__

(1) パラメータの設定

(2)単語の埋め込み(並列)

(3)変圧器

2、前進

(1)単語の埋め込み(並列)

(2)変圧器

(3)並列ロジット

(4) シリアルまたはパラレル出力


1. 構造図


2、コード分析

コードのこの部分は、model/gpt2_modeling.py にあります。

1、__初期化__

(1) パラメータの設定

  •  num_layers:transformerLayer のレイヤー数。
  • vocab_size: 辞書のサイズ;
  • hidden_​​size: 入力レイヤーのサイズ;
  • num_attention_heads: アテンションヘッドの数。
  • embedding_dropout_prob: 埋め込みのドロップアウト確率;
  • attention_dropout_prob: 自己注意力のドロップアウト確率;
  • Output_dropout_prob: 出力ドロップアウト確率;
  • max_sequence_length: 最大シーケンス長 (毎回読み取られるシーケンス長);
  • Checkpoint_activations: チェックポイントのアクティブ化を有効にするかどうか。
  • Checkpoint_num_layers: チェックポイント層の番号。
  • Parallel_output: 出力はシリアルまたはパラレルです。
  • query_window: スパース処理のウィンドウ サイズ。
  • key_window_times: スパース処理でのウィンドウの数を調整するために使用されます。
  • num_pivot: スパース処理のトークンの総数。
class GPT2Model(torch.nn.Module):
    """GPT-2 Language model.

    The output of the forward method are the logits (parallel or
    serial depending on the `parallel_output` flag.
    """

    def __init__(self,
                 num_layers,
                 vocab_size,
                 hidden_size,
                 num_attention_heads,
                 embedding_dropout_prob,
                 attention_dropout_prob,
                 output_dropout_prob,
                 max_sequence_length,
                 max_memory_length,
                 checkpoint_activations,
                 checkpoint_num_layers=1,
                 parallel_output=True,
                 query_window=128,
                 key_window_times=6,
                 num_pivot=768
                 ):

        super(GPT2Model, self).__init__()

        self.parallel_output = parallel_output

        init_method = init_method_normal(std=0.02)#初始化方法为高斯分布(均值为0,方差为0.02)

(2)単語の埋め込み(並列)

        # Word embeddings (parallel).
        self.word_embeddings = mpu.VocabParallelEmbedding(
            vocab_size, hidden_size, init_method=init_method)

詳細については、「CogView の Word embeddings (Parallel)」を参照してください- プログラマー募集

(3)変圧器

        # Transformer
        self.transformer = mpu.GPT2ParallelTransformer(num_layers,
                                                       hidden_size,
                                                       num_attention_heads,
                                                       max_sequence_length,
                                                       max_memory_length,
                                                       embedding_dropout_prob,
                                                       attention_dropout_prob,
                                                       output_dropout_prob,
                                                       checkpoint_activations,
                                                       checkpoint_num_layers,
                                                       query_window=query_window,
                                                       key_window_times=key_window_times,
                                                       num_pivot=num_pivot
                                                       )

詳細については、CogView の Transformer_ttya のブログを参照してください- CSDN ブログ

2、前進

    def forward(self, input_ids, position_ids, attention_mask, txt_indices_bool, img_indices_bool, is_sparse, *mems):

(1)単語の埋め込み(並列)

形状は (b、s、h) です。

补:b ——バッチサイズ;s——シーケンス長;h——hidden_​​size;

        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        embeddings = words_embeddings

(2)変圧器

        # Transformer.
        transformer_output = self.transformer(embeddings, position_ids, attention_mask, txt_indices_bool, img_indices_bool, is_sparse, *mems)
        logits, *hidden_layers = transformer_output#logits为output;*hidden_layers为*mem

(3)並列ロジット

        # Parallel logits.
        logits_parallel = mpu.copy_to_model_parallel_region(
            logits)#传递到模型并行区域
        logits_parallel = F.linear(logits_parallel,
                                   self.word_embeddings.weight)#线性变化

最終的な形状は (b,s,h)*(v/p,h)^T=(b,s,v/p) です。

v——vocab_size;p——パーティション数;

(4) シリアルまたはパラレル出力

        if self.parallel_output:#并行
            return (logits_parallel, *hidden_layers)

        return (mpu.gather_from_model_parallel_region(logits_parallel), *hidden_layers)#串行

皆さんもコメント欄で批判や修正を歓迎します、ありがとう~

おすすめ

転載: blog.csdn.net/weixin_55073640/article/details/126585021