Construcción general de la estructura de red en CogView

Soy un novato al principio, espero registrar lo que he aprendido como tomar notas, y también espero ayudar a las personas que también están comenzando.

Tabla de contenido

1. Diagrama de construcción

Dos, análisis de código

1,__inicio__

(1) Configuración de parámetros

(2) incrustaciones de palabras (paralelo)

(3) Transformador

2, adelante

(1)Incrustaciones de palabras (paralelas)

(2) Transformador

(3) Logitos paralelos

(4) Salida serial o paralela


1. Diagrama de construcción


Dos, análisis de código

Esta parte del código está en model/gpt2_modeling.py

1,__inicio__

(1) Configuración de parámetros

  •  num_layers: el número de capas de transformerLayer;
  • vocab_size: tamaño del diccionario;
  • hidden_size: tamaño de la capa de entrada;
  • num_attention_heads: el número de cabezas de atención;
  • embedding_dropout_prob: probabilidad de abandono de incrustación;
  • tention_dropout_prob: probabilidad de abandono de la autoatención;
  • output_dropout_prob: probabilidad de abandono de la salida;
  • max_sequence_length: longitud máxima de secuencia (longitud de secuencia leída cada vez);
  • checkpoint_activations: si habilitar la activación del punto de control;
  • checkpoint_num_layers: número de capa del punto de control;
  • salida_paralela: la salida es en serie o en paralelo;
  • query_window: tamaño de ventana para procesamiento escaso;
  • key_window_times: se utiliza para ajustar el número de ventanas en el procesamiento disperso;
  • num_pivot: el número total de tokens en procesamiento disperso;
class GPT2Model(torch.nn.Module):
    """GPT-2 Language model.

    The output of the forward method are the logits (parallel or
    serial depending on the `parallel_output` flag.
    """

    def __init__(self,
                 num_layers,
                 vocab_size,
                 hidden_size,
                 num_attention_heads,
                 embedding_dropout_prob,
                 attention_dropout_prob,
                 output_dropout_prob,
                 max_sequence_length,
                 max_memory_length,
                 checkpoint_activations,
                 checkpoint_num_layers=1,
                 parallel_output=True,
                 query_window=128,
                 key_window_times=6,
                 num_pivot=768
                 ):

        super(GPT2Model, self).__init__()

        self.parallel_output = parallel_output

        init_method = init_method_normal(std=0.02)#初始化方法为高斯分布(均值为0,方差为0.02)

(2) incrustaciones de palabras (paralelo)

        # Word embeddings (parallel).
        self.word_embeddings = mpu.VocabParallelEmbedding(
            vocab_size, hidden_size, init_method=init_method)

Consulte Incrustaciones de Word (paralelas) en CogView para obtener más detalles - Se busca programador

(3) Transformador

        # Transformer
        self.transformer = mpu.GPT2ParallelTransformer(num_layers,
                                                       hidden_size,
                                                       num_attention_heads,
                                                       max_sequence_length,
                                                       max_memory_length,
                                                       embedding_dropout_prob,
                                                       attention_dropout_prob,
                                                       output_dropout_prob,
                                                       checkpoint_activations,
                                                       checkpoint_num_layers,
                                                       query_window=query_window,
                                                       key_window_times=key_window_times,
                                                       num_pivot=num_pivot
                                                       )

Consulte el blog de Transformer_ttya en CogView para obtener más detalles - Blog de CSDN

2, adelante

    def forward(self, input_ids, position_ids, attention_mask, txt_indices_bool, img_indices_bool, is_sparse, *mems):

(1)Incrustaciones de palabras (paralelas)

La forma es (b, s, h)

补:b ——tamaño del lote ; s——longitud de la secuencia; h——tamaño_oculto;

        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        embeddings = words_embeddings

(2) Transformador

        # Transformer.
        transformer_output = self.transformer(embeddings, position_ids, attention_mask, txt_indices_bool, img_indices_bool, is_sparse, *mems)
        logits, *hidden_layers = transformer_output#logits为output;*hidden_layers为*mem

(3) Logitos paralelos

        # Parallel logits.
        logits_parallel = mpu.copy_to_model_parallel_region(
            logits)#传递到模型并行区域
        logits_parallel = F.linear(logits_parallel,
                                   self.word_embeddings.weight)#线性变化

La forma final es (b,s,h)*(v/p,h)^T=(b,s,v/p)

v——vocab_size;p——número de particiones;

(4) Salida serial o paralela

        if self.parallel_output:#并行
            return (logits_parallel, *hidden_layers)

        return (mpu.gather_from_model_parallel_region(logits_parallel), *hidden_layers)#串行

Todos son bienvenidos a criticar y corregir en el área de comentarios, gracias~

Supongo que te gusta

Origin blog.csdn.net/weixin_55073640/article/details/126585021
Recomendado
Clasificación