Named Entity Recognition (NER) with BERT: Implementation with the Trainer Class

NER with BERT

This implementation uses the Trainer class from Transformers.
Code:
https://gitee.com/liuyu_1997/ml-nlp/blob/master/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/nlp/BertNER/BertNER.ipynb

1. Loading the data

Load the data and take a look at it. We use the widely used CoNLL-2003 dataset for this experiment.

task = "ner"  # Should be one of "ner", "pos" or "chunk"
model_checkpoint = "distilbert-base-uncased"
batch_size = 16
from datasets import load_dataset, load_metric, Dataset

datasets = load_dataset("conll2003")

Show the first example of the training set:

datasets["train"][0]
{'id': '0',
 'tokens': ['EU',
  'rejects',
  'German',
  'call',
  'to',
  'boycott',
  'British',
  'lamb',
  '.'],
 'pos_tags': [22, 42, 16, 21, 35, 37, 16, 21, 7],
 'chunk_tags': [11, 21, 11, 12, 21, 22, 11, 12, 0],
 'ner_tags': [3, 0, 7, 0, 0, 0, 7, 0, 0]}
datasets["train"].features[f"ner_tags"].feature.names
['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']
datasets["train"].features[f"chunk_tags"].feature.names
['O',
 'B-ADJP',
 'I-ADJP',
 'B-ADVP',
 'I-ADVP',
 'B-CONJP',
 'I-CONJP',
 'B-INTJ',
 'I-INTJ',
 'B-LST',
 'I-LST',
 'B-NP',
 'I-NP',
 'B-PP',
 'I-PP',
 'B-PRT',
 'I-PRT',
 'B-SBAR',
 'I-SBAR',
 'B-UCP',
 'I-UCP',
 'B-VP',
 'I-VP']
datasets["train"].features[f"pos_tags"].feature.names
['"',
 "''",
 '#',
 '$',
 '(',
 ')',
 ',',
 '.',
 ':',
 '``',
 'CC',
 'CD',
 'DT',
 'EX',
 'FW',
 'IN',
 'JJ',
 'JJR',
 'JJS',
 'LS',
 'MD',
 'NN',
 'NNP',
 'NNPS',
 'NNS',
 'NN|SYM',
 'PDT',
 'POS',
 'PRP',
 'PRP$',
 'RB',
 'RBR',
 'RBS',
 'RP',
 'SYM',
 'TO',
 'UH',
 'VB',
 'VBD',
 'VBG',
 'VBN',
 'VBP',
 'VBZ',
 'WDT',
 'WP',
 'WP$',
 'WRB']
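Each value in ner_tags is an index into the NER label list printed above. As a quick check, the pure-Python sketch below maps the ids back to their names. It hard-codes the first training example and that label list rather than loading the dataset:

```python
# NER label names, copied from the ner_tags feature printed above.
ner_label_names = ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']

# First training example, copied from the output above.
tokens = ['EU', 'rejects', 'German', 'call', 'to', 'boycott', 'British', 'lamb', '.']
ner_tags = [3, 0, 7, 0, 0, 0, 7, 0, 0]

# Each tag id is simply a position in the label-name list.
named_tags = [ner_label_names[t] for t in ner_tags]
for token, tag in zip(tokens, named_tags):
    print(f"{token}\t{tag}")
# EU -> B-ORG, German -> B-MISC, British -> B-MISC, everything else -> O
```

Inside the notebook, the ClassLabel feature can do the same lookup directly, for example `datasets["train"].features["ner_tags"].feature.int2str(example["ner_tags"])`.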

2. Preprocessing the data

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

Besides raw strings, the tokenizer can also encode input that has already been split into words. To do this, set is_split_into_words=True:

tokenizer(["Hello", ",", "this", "is", "one", "sentence", "split", "into", "words", "."], is_split_into_words=True)
{'input_ids': [101, 7592, 1010, 2023, 2003, 2028, 6251, 3975, 2046, 2616, 1012, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
tokenizer.decode([101, 7592, 1010, 2023, 2003, 2028, 6251, 3975, 2046, 2616, 1012, 102])
'[CLS] hello, this is one sentence split into words. [SEP]'

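Note that the tokenizer may split a single word into subword pieces (for example a stem and an affix), so the sequence length can change after tokenization. Transformer models are usually pretrained with subword tokenizers. This means that even if the input is already split into words, the tokenizer can split each of those words again. Let's look at an example: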

example = datasets["train"][4]
print("Original tokens:", example["tokens"])
print("-"*100)
tokenized_input = tokenizer(example["tokens"], is_split_into_words=True)
tokens = tokenizer.convert_ids_to_tokens(tokenized_input["input_ids"])
print("Subword tokens:", tokens)
print("-"*100)
tokenizer.decode(tokenized_input["input_ids"])
Original tokens: ['Germany', "'s", 'representative', 'to', 'the', 'European', 'Union', "'s", 'veterinary', 'committee', 'Werner', 'Zwingmann', 'said', 'on', 'Wednesday', 'consumers', 'should', 'buy', 'sheepmeat', 'from', 'countries', 'other', 'than', 'Britain', 'until', 'the', 'scientific', 'advice', 'was', 'clearer', '.']
----------------------------------------------------------------------------------------------------
Subword tokens: ['[CLS]', 'germany', "'", 's', 'representative', 'to', 'the', 'european', 'union', "'", 's', 'veterinary', 'committee', 'werner', 'z', '##wing', '##mann', 'said', 'on', 'wednesday', 'consumers', 'should', 'buy', 'sheep', '##me', '##at', 'from', 'countries', 'other', 'than', 'britain', 'until', 'the', 'scientific', 'advice', 'was', 'clearer', '.', '[SEP]']
----------------------------------------------------------------------------------------------------

"[CLS] germany's representative to the european union's veterinary committee werner zwingmann said on wednesday consumers should buy sheepmeat from countries other than britain until the scientific advice was clearer. [SEP]"

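This means we need to process the labels. The list of ids returned by the tokenizer is longer than the list of labels in the dataset for two reasons: some words are split into several sub-tokens, and special tokens such as [CLS] and [SEP] are added.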

len(example[task+"_tags"]), len(tokenized_input["input_ids"])  # the lengths of ner_tags and input_ids do not match
(31, 39)

To handle this, we can use the tokenized_input.word_ids() method:

print(tokenized_input.word_ids())
print(len(tokenized_input.word_ids()))
[None, 0, 1, 1, 2, 3, 4, 5, 6, 7, 7, 8, 9, 10, 11, 11, 11, 12, 13, 14, 15, 16, 17, 18, 18, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, None]
39

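As we can see, the method returns a list with the same number of elements as the processed input ids. Special tokens are mapped to None, and every other token is mapped to the index of the word it comes from. Sub-tokens split from the same word therefore share the same index. This lets us align the labels with the processed input ids.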
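For instance, grouping the sub-tokens of training example 4 by their word id recovers which words were split. The sketch below copies the tokens and word ids from the outputs above, so it runs without a tokenizer:

```python
from collections import defaultdict

# Sub-word tokens and word ids of training example 4, copied from the outputs above.
tokens = ['[CLS]', 'germany', "'", 's', 'representative', 'to', 'the', 'european',
          'union', "'", 's', 'veterinary', 'committee', 'werner', 'z', '##wing',
          '##mann', 'said', 'on', 'wednesday', 'consumers', 'should', 'buy', 'sheep',
          '##me', '##at', 'from', 'countries', 'other', 'than', 'britain', 'until',
          'the', 'scientific', 'advice', 'was', 'clearer', '.', '[SEP]']
word_ids = [None, 0, 1, 1, 2, 3, 4, 5, 6, 7, 7, 8, 9, 10, 11, 11, 11, 12, 13, 14,
            15, 16, 17, 18, 18, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, None]

# Collect the sub-tokens of each word; special tokens (word id None) are skipped.
pieces = defaultdict(list)
for token, word_id in zip(tokens, word_ids):
    if word_id is not None:
        pieces[word_id].append(token)

# Keep only the words that the tokenizer split into more than one piece.
split_words = {w: p for w, p in pieces.items() if len(p) > 1}
print(split_words)
# {1: ["'", 's'], 7: ["'", 's'], 11: ['z', '##wing', '##mann'], 18: ['sheep', '##me', '##at']}
```

Word 11 is "Zwingmann" and word 18 is "sheepmeat" in the original token list. These are exactly the positions where one word label has to cover several input ids.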

word_ids = tokenized_input.word_ids()
aligned_labels = [-100 if i is None else example[f"{task}_tags"][i] for i in word_ids]
print(len(aligned_labels), len(tokenized_input["input_ids"]))
39 39

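Here we set the label of every special token to -100 and the label of every other token to the label of the word it comes from. The value -100 is the default ignore_index of PyTorch's CrossEntropyLoss, so these positions do not contribute to the loss. Another strategy is to label only the first token of each word and give the remaining sub-tokens of that word the label -100. Both strategies are supported here; switch between them by changing the value of the following flag: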

label_all_tokens = True  # True: first strategy (every sub-token gets the word's label); False: second strategy (only the first sub-token is labeled)
def tokenize_and_align_labels(examples):
    tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)

    labels = []
    for i, label in enumerate(examples[f"{task}_tags"]):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            # Special tokens ([CLS], [SEP]) have word id None: label them -100 so the loss ignores them.
            if word_idx is None:
                label_ids.append(-100)
            # The first token of each word gets that word's label.
            elif word_idx != previous_word_idx:
                label_ids.append(label[word_idx])
            # For the other tokens in a word, we set the label to either the current label or -100, depending on
            # the label_all_tokens flag.
            else:
                label_ids.append(label[word_idx] if label_all_tokens else -100)
            previous_word_idx = word_idx

        labels.append(label_ids)

    tokenized_inputs["labels"] = labels
    return tokenized_inputs
tokenize_and_align_labels(datasets['train'][:1])
{'input_ids': [[101, 7327, 19164, 2446, 2655, 2000, 17757, 2329, 12559, 1012, 102]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], 'labels': [[-100, 3, 0, 7, 0, 0, 0, 7, 0, 0, -100]]}
tokenizer.batch_decode(tokenize_and_align_labels(datasets['train'][:1])["input_ids"])
['[CLS] eu rejects german call to boycott british lamb. [SEP]']
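The first training example contains no word that is split into several sub-tokens, so both strategies give the same labels for it. To see where they differ, the sketch below reimplements the per-example loop of tokenize_and_align_labels as a standalone function. The name align_labels is hypothetical, and the word_ids list is a toy input assumed for illustration, in which the second word is split in two:

```python
def align_labels(word_ids, word_labels, label_all_tokens):
    """Hypothetical helper: the per-example loop of tokenize_and_align_labels."""
    label_ids = []
    previous_word_idx = None
    for word_idx in word_ids:
        if word_idx is None:                  # special token ([CLS] / [SEP])
            label_ids.append(-100)
        elif word_idx != previous_word_idx:   # first sub-token of a word
            label_ids.append(word_labels[word_idx])
        else:                                 # later sub-token of the same word
            label_ids.append(word_labels[word_idx] if label_all_tokens else -100)
        previous_word_idx = word_idx
    return label_ids

# Toy input (assumed): 3 words, the second one (label 0 = 'O') split into 2 sub-tokens.
word_ids = [None, 0, 1, 1, 2, None]
word_labels = [3, 0, 7]

print(align_labels(word_ids, word_labels, label_all_tokens=True))   # [-100, 3, 0, 0, 7, -100]
print(align_labels(word_ids, word_labels, label_all_tokens=False))  # [-100, 3, 0, -100, 7, -100]
```

The only difference is the second sub-token of the split word. It either repeats the word's label or is masked out with -100.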

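Apply the alignment to the whole dataset by calling map(). With batched=True, tokenize_and_align_labels receives a batch of examples at a time: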

datasets
DatasetDict({
    train: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 14042
    })
    validation: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 3251
    })
    test: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 3454
    })
})
tokenized_datasets = datasets.map(tokenize_and_align_labels, batched=True)
tokenized_datasets
DatasetDict({
    train: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 14042
    })
    validation: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 3251
    })
    test: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 3454
    })
})
tokenized_datasets["train"]["labels"][0]
[-100, 3, 0, 7, 0, 0, 0, 7, 0, 0, -100]
tokenized_datasets["train"]["ner_tags"][0]
[3, 0, 7, 0, 0, 0, 7, 0, 0]

3. Training the model

Fine-tune the model:

from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer
label_list = datasets["train"].features[f"{task}_tags"].feature.names
model = AutoModelForTokenClassification.from_pretrained(model_checkpoint, num_labels=len(label_list))
Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForTokenClassification: ['vocab_projector.bias', 'vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight']
- This IS expected if you are initializing DistilBertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForTokenClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
model_name = model_checkpoint.split("/")[-1]
args = TrainingArguments(
    f"{model_name}-finetuned-{task}",
    evaluation_strategy="epoch",  # renamed to eval_strategy in newer Transformers releases
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=3,
    weight_decay=0.01,
)
from transformers import DataCollatorForTokenClassification

data_collator = DataCollatorForTokenClassification(tokenizer)
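tokenize_and_align_labels only truncates and does not pad, so the sequences in a batch have different lengths. DataCollatorForTokenClassification pads each batch on the fly and fills the labels with -100 (its default label_pad_token_id), so the loss ignores the padded positions. The sketch below is a simplified, hypothetical pure-Python version. pad_batch is not a library function, and it returns lists instead of tensors. It assumes right padding with pad id 0, which is the [PAD] id of the uncased BERT vocabulary:

```python
def pad_batch(features, pad_token_id=0, label_pad_token_id=-100):
    """Hypothetical sketch of dynamic padding for token classification.

    Pads every sequence to the length of the longest one in the batch:
    input_ids with pad_token_id, attention_mask with 0, labels with -100.
    """
    max_len = max(len(f["input_ids"]) for f in features)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n_pad = max_len - len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * n_pad)
        batch["attention_mask"].append(f["attention_mask"] + [0] * n_pad)
        batch["labels"].append(f["labels"] + [label_pad_token_id] * n_pad)
    return batch

batch = pad_batch([
    # First training example, from the tokenize_and_align_labels output above.
    {"input_ids": [101, 7327, 19164, 2446, 2655, 2000, 17757, 2329, 12559, 1012, 102],
     "attention_mask": [1] * 11,
     "labels": [-100, 3, 0, 7, 0, 0, 0, 7, 0, 0, -100]},
    # Hypothetical shorter example: [CLS] hello [SEP], with "hello" labeled O.
    {"input_ids": [101, 7592, 102],
     "attention_mask": [1, 1, 1],
     "labels": [-100, 0, -100]},
])
print(batch["input_ids"][1])  # [101, 7592, 102, 0, 0, 0, 0, 0, 0, 0, 0]
print(batch["labels"][1])     # [-100, 0, -100] followed by eight -100 padding labels
```

Padding per batch rather than to a fixed maximum length keeps short batches cheap. This is why the preprocessing step can leave padding to the collator.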
!pip install seqeval
metric = load_metric("seqeval")  # load_metric is deprecated in newer datasets releases; evaluate.load("seqeval") replaces it

Evaluation metric


import numpy as np

def compute_metrics(p):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)

    # Remove ignored index (special tokens)
    true_predictions = [
        [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]

    results = metric.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }
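To make the filtering step concrete, the sketch below runs the same list comprehensions on a tiny hand-made batch. The logits and labels are invented for illustration and are not model output. Positions whose label is -100 ([CLS] and [SEP] here) are dropped from both lists before they reach seqeval:

```python
import numpy as np

label_list = ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']

# Invented batch: 1 sequence, 4 positions ([CLS], word 1, word 2, [SEP]), 9 classes.
toy_logits = np.zeros((1, 4, len(label_list)))
toy_logits[0, 1, 3] = 1.0                      # position 1 scores highest on 'B-ORG'
toy_labels = np.array([[-100, 3, 7, -100]])    # gold: B-ORG, B-MISC

# Predicted class index per position; all-zero rows fall back to index 0 ('O').
toy_predictions = np.argmax(toy_logits, axis=2)

toy_true_predictions = [
    [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
    for prediction, label in zip(toy_predictions, toy_labels)
]
toy_true_labels = [
    [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
    for prediction, label in zip(toy_predictions, toy_labels)
]
print(toy_true_predictions)  # [['B-ORG', 'O']]
print(toy_true_labels)       # [['B-ORG', 'B-MISC']]
```

Both lists keep the same length per sentence. seqeval requires this because it compares predictions and references position by position when it builds entity chunks.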
trainer = Trainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

Training

trainer.train()
The following columns in the training set  don't have a corresponding argument in `DistilBertForTokenClassification.forward` and have been ignored: id, chunk_tags, pos_tags, ner_tags, tokens. If id, chunk_tags, pos_tags, ner_tags, tokens are not expected by `DistilBertForTokenClassification.forward`,  you can safely ignore this message.
/home/zutnlp/miniconda3/envs/liuyu/lib/python3.7/site-packages/transformers/optimization.py:309: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
  FutureWarning,
***** Running training *****
  Num examples = 14042
  Num Epochs = 3
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 2634

Saving model checkpoint to distilbert-base-uncased-finetuned-ner/checkpoint-500
Configuration saved in distilbert-base-uncased-finetuned-ner/checkpoint-500/config.json
Model weights saved in distilbert-base-uncased-finetuned-ner/checkpoint-500/pytorch_model.bin
tokenizer config file saved in distilbert-base-uncased-finetuned-ner/checkpoint-500/tokenizer_config.json
Special tokens file saved in distilbert-base-uncased-finetuned-ner/checkpoint-500/special_tokens_map.json
The following columns in the evaluation set  don't have a corresponding argument in `DistilBertForTokenClassification.forward` and have been ignored: id, chunk_tags, pos_tags, ner_tags, tokens. If id, chunk_tags, pos_tags, ner_tags, tokens are not expected by `DistilBertForTokenClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 3251
  Batch size = 16
Saving model checkpoint to distilbert-base-uncased-finetuned-ner/checkpoint-1000
Configuration saved in distilbert-base-uncased-finetuned-ner/checkpoint-1000/config.json
Model weights saved in distilbert-base-uncased-finetuned-ner/checkpoint-1000/pytorch_model.bin
tokenizer config file saved in distilbert-base-uncased-finetuned-ner/checkpoint-1000/tokenizer_config.json
Special tokens file saved in distilbert-base-uncased-finetuned-ner/checkpoint-1000/special_tokens_map.json
Saving model checkpoint to distilbert-base-uncased-finetuned-ner/checkpoint-1500
Configuration saved in distilbert-base-uncased-finetuned-ner/checkpoint-1500/config.json
Model weights saved in distilbert-base-uncased-finetuned-ner/checkpoint-1500/pytorch_model.bin
tokenizer config file saved in distilbert-base-uncased-finetuned-ner/checkpoint-1500/tokenizer_config.json
Special tokens file saved in distilbert-base-uncased-finetuned-ner/checkpoint-1500/special_tokens_map.json
The following columns in the evaluation set  don't have a corresponding argument in `DistilBertForTokenClassification.forward` and have been ignored: id, chunk_tags, pos_tags, ner_tags, tokens. If id, chunk_tags, pos_tags, ner_tags, tokens are not expected by `DistilBertForTokenClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 3251
  Batch size = 16
Saving model checkpoint to distilbert-base-uncased-finetuned-ner/checkpoint-2000
Configuration saved in distilbert-base-uncased-finetuned-ner/checkpoint-2000/config.json
Model weights saved in distilbert-base-uncased-finetuned-ner/checkpoint-2000/pytorch_model.bin
tokenizer config file saved in distilbert-base-uncased-finetuned-ner/checkpoint-2000/tokenizer_config.json
Special tokens file saved in distilbert-base-uncased-finetuned-ner/checkpoint-2000/special_tokens_map.json
Saving model checkpoint to distilbert-base-uncased-finetuned-ner/checkpoint-2500
Configuration saved in distilbert-base-uncased-finetuned-ner/checkpoint-2500/config.json
Model weights saved in distilbert-base-uncased-finetuned-ner/checkpoint-2500/pytorch_model.bin
tokenizer config file saved in distilbert-base-uncased-finetuned-ner/checkpoint-2500/tokenizer_config.json
Special tokens file saved in distilbert-base-uncased-finetuned-ner/checkpoint-2500/special_tokens_map.json
The following columns in the evaluation set  don't have a corresponding argument in `DistilBertForTokenClassification.forward` and have been ignored: id, chunk_tags, pos_tags, ner_tags, tokens. If id, chunk_tags, pos_tags, ner_tags, tokens are not expected by `DistilBertForTokenClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 3251
  Batch size = 16


Training completed. Do not forget to share your model on huggingface.co/models =)

TrainOutput(global_step=2634, training_loss=0.08670986667218857, metrics={'train_runtime': 96.0161, 'train_samples_per_second': 438.739, 'train_steps_per_second': 27.433, 'total_flos': 510309848641824.0, 'train_loss': 0.08670986667218857, 'epoch': 3.0})
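The step counts in this log follow from the training arguments. Below is a small arithmetic sketch. It assumes a single device, no gradient accumulation, and the default save_steps of 500, since the TrainingArguments above do not set it:

```python
import math

num_examples = 14042   # "Num examples" in the training log
batch_size = 16        # per_device_train_batch_size (single device assumed)
num_epochs = 3         # num_train_epochs
save_steps = 500       # assumed: TrainingArguments default, not set explicitly above

# The last, partial batch of each epoch still counts as a step.
steps_per_epoch = math.ceil(num_examples / batch_size)
total_steps = steps_per_epoch * num_epochs
checkpoints = list(range(save_steps, total_steps + 1, save_steps))

print(steps_per_epoch, total_steps)  # 878 2634
print(checkpoints)                   # [500, 1000, 1500, 2000, 2500]
```

This matches the "Total optimization steps = 2634" line and the checkpoint-500 to checkpoint-2500 directories. With evaluation_strategy="epoch", the validation set is evaluated once per epoch, which gives the three "Running Evaluation" blocks.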

Evaluation

trainer.evaluate()
The following columns in the evaluation set  don't have a corresponding argument in `DistilBertForTokenClassification.forward` and have been ignored: id, chunk_tags, pos_tags, ner_tags, tokens. If id, chunk_tags, pos_tags, ner_tags, tokens are not expected by `DistilBertForTokenClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 3251
  Batch size = 16

{'eval_loss': 0.061025530099868774,
 'eval_precision': 0.9237063246351173,
 'eval_recall': 0.9345564380803222,
 'eval_f1': 0.9290997052772062,
 'eval_accuracy': 0.9831127774159213,
 'eval_runtime': 2.6055,
 'eval_samples_per_second': 1247.758,
 'eval_steps_per_second': 78.297,
 'epoch': 3.0}
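As a sanity check on these numbers: seqeval's overall F1 is the harmonic mean of overall precision and recall, F1 = 2PR / (P + R). Plugging in the reported values reproduces eval_f1:

```python
precision = 0.9237063246351173  # eval_precision above
recall = 0.9345564380803222     # eval_recall above

# Harmonic mean of precision and recall (F-beta with beta = 1).
f1 = 2 * precision * recall / (precision + recall)
print(f1)  # agrees with eval_f1 up to floating-point rounding
```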

Testing

predictions, labels, _ = trainer.predict(tokenized_datasets["test"])
predictions = np.argmax(predictions, axis=2)

# Remove ignored index (special tokens)
true_predictions = [
    [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
    for prediction, label in zip(predictions, labels)
]
true_labels = [
    [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
    for prediction, label in zip(predictions, labels)
]

results = metric.compute(predictions=true_predictions, references=true_labels)
results
The following columns in the test set  don't have a corresponding argument in `DistilBertForTokenClassification.forward` and have been ignored: id, chunk_tags, pos_tags, ner_tags, tokens. If id, chunk_tags, pos_tags, ner_tags, tokens are not expected by `DistilBertForTokenClassification.forward`,  you can safely ignore this message.
***** Running Prediction *****
  Num examples = 3454
  Batch size = 16

{'LOC': {'precision': 0.8881818181818182,
  'recall': 0.9199623352165726,
  'f1': 0.9037927844588344,
  'number': 2124},
 'MISC': {'precision': 0.7567567567567568,
  'recall': 0.7309236947791165,
  'f1': 0.7436159346271707,
  'number': 996},
 'ORG': {'precision': 0.8615443134271586,
  'recall': 0.875193199381762,
  'f1': 0.8683151236342725,
  'number': 2588},
 'PER': {'precision': 0.9602673598217601,
  'recall': 0.9514348785871964,
  'f1': 0.9558307152097579,
  'number': 2718},
 'overall_precision': 0.887906647807638,
 'overall_recall': 0.8940185141229527,
 'overall_f1': 0.8909520993494974,
 'overall_accuracy': 0.974760030384642}
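seqeval scores at the entity level. It first groups the tag sequence of each sentence into entity chunks and then computes precision and recall over those chunks; "number" is the count of gold entities of each type. The sketch below illustrates that grouping with a hypothetical helper, bio_to_spans. It follows a lenient conlleval-style rule that is assumed to match seqeval's default mode: a B- tag always opens a new chunk, and an I- tag opens one when it does not continue a chunk of the same type.

The sketch also shows a side effect of label_all_tokens = True. When a B- word is split into several sub-tokens, every sub-token carries the B- label, and each copy is counted as a separate entity. The counts above are therefore taken over sub-tokens and can exceed the word-level entity counts of the test set.

```python
def bio_to_spans(tags):
    """Hypothetical helper: group BIO tags into (type, start, end_exclusive) spans.

    Assumed rule (lenient, conlleval-style): B-X always starts a new chunk;
    I-X starts a new chunk unless it continues a chunk of type X.
    """
    spans = []
    current_type, start = None, None
    for i, tag in enumerate(list(tags) + ["O"]):  # sentinel "O" closes the last chunk
        prefix, _, etype = tag.partition("-")
        starts_new = prefix == "B" or (prefix == "I" and etype != current_type)
        if current_type is not None and (prefix == "O" or starts_new):
            spans.append((current_type, start, i))
            current_type = None
        if starts_new:
            current_type, start = etype, i
    return spans

# One B-PER word split into three sub-tokens, labeled with label_all_tokens = True:
print(bio_to_spans(["B-PER", "B-PER", "B-PER"]))  # [('PER', 0, 1), ('PER', 1, 2), ('PER', 2, 3)]
# The same word with continuation sub-tokens tagged I-PER forms a single entity:
print(bio_to_spans(["B-PER", "I-PER", "I-PER"]))  # [('PER', 0, 3)]
```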


Reposted from blog.csdn.net/q506610466/article/details/124759506