pytorch_transformers包含BERT, GPT, GPT-2, Transfo-XL, XLNet, XLM 等多个模型,并提供了27 个预训练模型。
对于每个模型,pytorch_transformers库里都对应有三个类:
- model classes是模型的网络结构
- configuration classes是模型的相关参数
- tokenizer classes是分词工具,一般建议直接使用from_pretrained()方法加载已经预训练好的模型或者参数。
model_bert.py
官方文档: huggingface bert 说明文档
class BertModel(config)
一个单独的BERT模型 输出raw hidden-state 没有任何specific head. BERT 是一个双向transformer 使用masked language modeling 和 next sentence prediction在TornotoBook和wiipedia上 训练的预训练模型。
参数
- config(BertConfig):模型configuration class 包含了模型的所有参数。这里只是初始化没有下载所有相关的参数矩阵。通过from_pretrained() 函数来加载模型的所有参数。
输入
- input_ids: (batch_size, sequence_length), torch.LongTensor 输入序列的indices, 为了匹配pre-training, BERT输入序列格式应该遵循[CLS] X [SEP] X[SEP] 。
对于一个句子:
tokens: [CLS] x x x .[SEP]
token_type_ids: 0 0 0 0 0 0
对于2个句子:
tokens: [CLS] x x x x? [SEP] y y y y.[SEP]
token_type_ids: 0 0 0 0 0 0 0 1 1 1 1 1 1
BERT 有position embedding,所以建议pad时,在句子的右侧进行pad - attention_mask: (optional) (batch_size, sequence_length)torch.FloatTensor.
Mask for padding tokens。Mask value:[0, 1],其中1 是 不被mask的token 位置,0 是被mask的token. - token_type_ids: (batch_size, sequence_length)torch.LongTensor.
段落token indices, 用于区分输入的第一部分和第二部分。[0, 1]: 0是第一句,1是第二句。 - position_ids: (batch_size, sequence_length) torch.LongTensor.
每个输入句子的位置嵌入 positional embeddings.
输入范围 [0, config.max_position_embeddings-1 ] - head_mask: (num_heads, ) 或者(num_layers, num_heads) torch.FloatTensor.
使选到的self-attention模块无效。Mask value[0,1]: 1是非mask的head, 0是mask的head.
输出
outputs: 是一个tuple
-
last_hidden_state: (batch_size, sequence_length, hidden_size)torch.FloatTensor.
模型最后一层的hidden-states 序列 -
pooler_output: (batch_size, hidden_size)
最后一层hidden-state的第一个token(CLS) 通过一个Linear layer 和一个Tanh 激活函数的输出。 -
hidden_states: (当config.output_hidden_state=True时 会说输出)
-
一个list , 每个元素都是(batch_size, sequence_length, hidden_size)(每层的hidden-states + embedding outputs)
-
attentions: (当config.output_attentions=True时会输出)
一个list , 里面元素(batch_size, num_heads, sequence_length, sequence_length)(Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.)
forward
forward(input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None)
定义一个模型训练方式,这个应该被subclass overwrite.
BertForPreTraining
class transormers.BertForPreTraining(config)