解析pytorch_transformer之model_bert.py

pytorch_transformers包含BERT, GPT, GPT-2, Transfo-XL, XLNet, XLM 等多个模型,并提供了27 个预训练模型。

对于每个模型,pytorch_transformers库里都对应有三个类:

  • model classes是模型的网络结构
  • configuration classes是模型的相关参数
  • tokenizer classes是分词工具,一般建议直接使用from_pretrained()方法加载已经预训练好的模型或者参数。

model_bert.py

官方文档: huggingface bert 说明文档

class BertModel(config)

一个单独的BERT模型 输出raw hidden-state 没有任何specific head. BERT 是一个双向transformer 使用masked language modeling 和 next sentence prediction在TornotoBook和wiipedia上 训练的预训练模型。

参数
  • config(BertConfig):模型configuration class 包含了模型的所有参数。这里只是初始化没有下载所有相关的参数矩阵。通过from_pretrained() 函数来加载模型的所有参数。
输入
  • input_ids: (batch_size, sequence_length), torch.LongTensor 输入序列的indices, 为了匹配pre-training, BERT输入序列格式应该遵循[CLS] X [SEP] X[SEP] 。
    对于一个句子:
    tokens: [CLS] x x x .[SEP]
    token_type_ids: 0 0 0 0 0 0
    对于2个句子:
    tokens: [CLS] x x x x? [SEP] y y y y.[SEP]
    token_type_ids: 0 0 0 0 0 0 0 1 1 1 1 1 1
    BERT 有position embedding,所以建议pad时,在句子的右侧进行pad
  • attention_mask: (optional) (batch_size, sequence_length)torch.FloatTensor.
    Mask for padding tokens。Mask value:[0, 1],其中1 是 不被mask的token 位置,0 是被mask的token.
  • token_type_ids: (batch_size, sequence_length)torch.LongTensor.
    段落token indices, 用于区分输入的第一部分和第二部分。[0, 1]: 0是第一句,1是第二句。
  • position_ids: (batch_size, sequence_length) torch.LongTensor.
    每个输入句子的位置嵌入 positional embeddings.
    输入范围 [0, config.max_position_embeddings-1 ]
  • head_mask: (num_heads, ) 或者(num_layers, num_heads) torch.FloatTensor.
    使选到的self-attention模块无效。Mask value[0,1]: 1是非mask的head, 0是mask的head.
输出

outputs: 是一个tuple

  • last_hidden_state: (batch_size, sequence_length, hidden_size)torch.FloatTensor.
    模型最后一层的hidden-states 序列

  • pooler_output: (batch_size, hidden_size)
    最后一层hidden-state的第一个token(CLS) 通过一个Linear layer 和一个Tanh 激活函数的输出。

  • hidden_states: (当config.output_hidden_state=True时 会说输出)

  • 一个list , 每个元素都是(batch_size, sequence_length, hidden_size)(每层的hidden-states + embedding outputs)

  • attentions: (当config.output_attentions=True时会输出)
    一个list , 里面元素(batch_size, num_heads, sequence_length, sequence_length)(Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.)

forward

forward(input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None)

定义一个模型训练方式,这个应该被subclass overwrite.

BertForPreTraining

class transormers.BertForPreTraining(config)
发布了28 篇原创文章 · 获赞 5 · 访问量 4344

猜你喜欢

转载自blog.csdn.net/m0_37531129/article/details/101612062
今日推荐