A Text-Mining Model for Assessing the Quality of Enterprise Hazard Self-Inspection Reports



www.sodic.com.cn/competition…

1 Competition Name

基于文本挖掘的企业隐患排查质量分析模型 (A Text-Mining Model for Assessing the Quality of Enterprise Hazard Self-Inspection Reports)

2 Background

Having enterprises self-report safety hazards plays an important role in eliminating risks while an incident is still in the bud. When filling in these reports, however, enterprises are often careless and submit false or fabricated hazard content, which makes supervision harder. Using big-data techniques to analyse the reported hazard text, identify enterprises that do not genuinely fulfil their primary safety responsibility, and push them to the regulator for targeted enforcement can make supervision more effective and strengthen enterprises' awareness of their safety responsibilities.

3 Task

The competition provides hazard records filled in by enterprises; participants need to use automated methods to identify whether a record is a false or fabricated report.

Reading the problem statement carefully is essential: make sure you understand the goal before you start modelling, and you will avoid many detours.

4 Data Description

The dataset consists of de-identified self-inspection hazard records filled in by enterprises.

Data description: each training record has 7 fields: id, level_1 (level-1 standard), level_2 (level-2 standard), level_3 (level-3 standard), level_4 (level-4 standard), content (hazard description) and label. "id" is the primary key and carries no business meaning. The four level fields are the inspection guidelines defined in the Shenzhen Basic Guidelines for Safety Hazard Self-Inspection and Patrol Inspection (2016 revision, 《深圳市安全隐患自查和巡查基本指引(2016年修订版)》): level_1 corresponds to the hazard type, levels 2 to 4 progressively refine it, and enterprises perform their self-inspection against the level-4 items for each hazard type. "content" is the concrete hazard reported by the enterprise. "label" marks whether the report is acceptable: 1 means the report is unqualified, 0 means it is qualified.

Prediction file: results.csv

Column  Description
id      enterprise id
label   predicted class (0 = qualified, 1 = unqualified)
  • File name: results.csv, UTF-8 encoded
  • Submit the model output as a csv/json file; the platform scores submissions online and updates the leaderboard in real time (a minimal submission sketch follows).
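A minimal sketch of writing the submission file (the ids and labels below are made-up placeholders, not real predictions):

import pandas as pd

# hypothetical ids and predicted labels, just to illustrate the expected file layout
submission = pd.DataFrame({'id': [0, 1, 2], 'label': [0, 1, 0]})
# the platform expects a UTF-8 encoded results.csv
submission.to_csv('results.csv', index=False, encoding='utf-8')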

5 Evaluation Metric

Submissions are scored with the F1-score.
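For reference, F1 is the harmonic mean of precision and recall, F1 = 2PR / (P + R). A tiny worked example with made-up labels (not competition data):

from sklearn.metrics import f1_score

# made-up ground truth and predictions: precision = 2/3, recall = 2/3, so F1 = 2/3
y_true = [1, 1, 1, 0, 0, 0]
y_pred = [1, 1, 0, 1, 0, 0]
print(f1_score(y_true, y_pred))  # 0.666...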

6 Walkthrough Notes

1 Importing packages

#  pip install -i https://pypi.tuna.tsinghua.edu.cn/simple transformers
# import transformers
import transformers
# from transformers import BertModel, BertTokenizer,BertConfig, AdamW, get_linear_schedule_with_warmup

from transformers import AutoModel, AutoTokenizer,AutoConfig, AdamW, get_linear_schedule_with_warmup

# import torch
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F


# common utility packages
import re
import numpy as np
import pandas as pd
import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt
from matplotlib import rc
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
from collections import defaultdict
from textwrap import wrap


%matplotlib inline
%config InlineBackend.figure_format='retina' # render inline figures at high resolution
sns.set(style='whitegrid', palette='muted', font_scale=1.2)
HAPPY_COLORS_PALETTE = ["#01BEFE", "#FFDD00", "#FF7D00", "#FF006D", "#ADFF02", "#8F00FF"]
sns.set_palette(sns.color_palette(HAPPY_COLORS_PALETTE))
rcParams['figure.figsize'] = 12, 8

# fix random seeds for reproducibility
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device
device(type='cuda', index=0)
torch.cuda.is_available()
True

2 Loading the data

train=pd.read_csv('/home/mw/input/task026741/sub.csv')
test=pd.read_csv('/home/mw/input/task026741/test.csv')
sub=pd.read_csv('/home/mw/input/task026741/train.csv')
# note: on this platform the file names do not match their contents; the file called sub.csv
# holds the 12000x7 training table and train.csv the 18000x2 submission template,
# as the shapes printed below confirm
# train=pd.read_csv('data/02/train.csv')
# test=pd.read_csv('data/02/test.csv')
# sub=pd.read_csv('data/02/sub.csv')

print("train.shape,test.shape,sub.shape",train.shape,test.shape,sub.shape)
train.shape,test.shape,sub.shape (12000, 7) (18000, 6) (18000, 2)
# show the first three rows
train.head(3)
id level_1 level_2 level_3 level_4 content label
0 0 工业/危化品类(现场)—2016版 (二)电气安全 6、移动用电产品、电动工具及照明 1、移动使用的用电产品和I类电动工具的绝缘线,必须采用三芯(单相)或四芯(三相)多股铜芯橡套软线。 使用移动手动电动工具,外接线绝缘皮破损,应停止使用. 0
1 1 工业/危化品类(现场)—2016版 (一)消防检查 1、防火巡查 3、消防设施、器材和消防安全标志是否在位、完整; 一般 1
2 2 工业/危化品类(现场)—2016版 (一)消防检查 2、防火检查 6、重点工种人员以及其他员工消防知识的掌握情况; 消防知识要加强 0

2.1 Checking for missing values

train.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 12000 entries, 0 to 11999
Data columns (total 7 columns):
 #   Column   Non-Null Count  Dtype 
---  ------   --------------  ----- 
 0   id       12000 non-null  int64 
 1   level_1  12000 non-null  object
 2   level_2  12000 non-null  object
 3   level_3  12000 non-null  object
 4   level_4  12000 non-null  object
 5   content  11998 non-null  object
 6   label    12000 non-null  int64 
dtypes: int64(2), object(5)
memory usage: 656.4+ KB
train[train['content'].isna()] # content is the key field, so inspect its missing rows
id level_1 level_2 level_3 level_4 content label
6193 6193 工业/危化品类(现场)—2016版 (一)消防检查 1、防火巡查 3、消防设施、器材和消防安全标志是否在位、完整; NaN 1
9248 9248 工业/危化品类(现场)—2016版 (一)消防检查 1、防火巡查 4、常闭式防火门是否处于关闭状态,防火卷帘下是否堆放物品影响使用; NaN 1
test.info() # 4 rows have an empty content field
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 18000 entries, 0 to 17999
Data columns (total 6 columns):
 #   Column   Non-Null Count  Dtype 
---  ------   --------------  ----- 
 0   id       18000 non-null  int64 
 1   level_1  18000 non-null  object
 2   level_2  18000 non-null  object
 3   level_3  18000 non-null  object
 4   level_4  18000 non-null  object
 5   content  17996 non-null  object
dtypes: int64(1), object(5)
memory usage: 843.9+ KB
print("train null nums")
print(train.shape[0]-train.count())
print("test null nums")
print(test.shape[0]-test.count())
train null nums
id         0
level_1    0
level_2    0
level_3    0
level_4    0
content    2
label      0
dtype: int64
test null nums
id         0
level_1    0
level_2    0
level_3    0
level_4    0
content    4
dtype: int64

2.2 Label distribution

Tip: for every NLP task, inspect the distribution of the answers or labels first. For classification, check the per-class counts; for regression, the distribution of the target values; for NER, the distribution of the annotated tags.

train['label'].value_counts()
0    10712
1     1288
Name: label, dtype: int64
sns.countplot(train.label)
plt.xlabel('label count')
Text(0.5, 0, 'label count')

[Figure output_18_1.png: count plot of the two label classes]

1288/10712
0.12023898431665422

3 Data preprocessing

# fill missing values (the literal string '空值' is kept as the fill token fed to the model)
train['content']=train['content'].fillna('空值')
test['content']=test['content'].fillna('空值')
train['level_1']=train['level_1'].apply(lambda x:x.split('(')[0])
train['level_2']=train['level_2'].apply(lambda x:x.split(')')[-1])
train['level_3']=train['level_3'].apply(lambda x:re.split(r'[0-9]、',x)[-1])
train['level_4']=train['level_4'].apply(lambda x:re.split(r'[0-9]、',x)[-1])

test['level_1']=test['level_1'].apply(lambda x:x.split('(')[0])
test['level_2']=test['level_2'].apply(lambda x:x.split(')')[-1])
test['level_3']=test['level_3'].apply(lambda x:re.split(r'[0-9]、',x)[-1])
test['level_4']=test['level_4'].apply(lambda x:re.split(r'[0-9]、',x)[-1])
train
id level_1 level_2 level_3 level_4 content label
0 0 工业/危化品类 电气安全 移动用电产品、电动工具及照明 移动使用的用电产品和I类电动工具的绝缘线,必须采用三芯(单相)或四芯(三相)多股铜芯橡套软线。 使用移动手动电动工具,外接线绝缘皮破损,应停止使用. 0
1 1 工业/危化品类 消防检查 防火巡查 消防设施、器材和消防安全标志是否在位、完整; 一般 1
2 2 工业/危化品类 消防检查 防火检查 重点工种人员以及其他员工消防知识的掌握情况; 消防知识要加强 0
3 3 工业/危化品类 消防检查 防火巡查 消防设施、器材和消防安全标志是否在位、完整; 消防通道有货物摆放 清理不及时 0
4 4 工业/危化品类 消防检查 防火巡查 常闭式防火门是否处于关闭状态,防火卷帘下是否堆放物品影响使用; 防火门打开状态 0
... ... ... ... ... ... ... ...
11995 11995 商贸服务教文卫类 安全教育培训 员工安全教育 制定安全教育培训计划,确保全员参与培训,并建立安全培训档案。 个别员工对消防栓的使用不熟练\r\n 0
11996 11996 工业/危化品类 电气安全 电气线路及电源插头插座 电源插座、电源插头应按规定正确接线。 化验室超净台照明电源线防护不足,且检测台金属架未安装漏电接地保护线。整改措施:更换照明灯为前... 0
11997 11997 工业/危化品类 机械设备安全防护 人身防护 皮带轮、齿轮、凸轮、曲柄连杆机构等外露的转动和运动部件应有防护罩。 电箱、马达,没有防护罩,现在整改 0
11998 11998 工业/危化品类 作业环境 通风与照明 作业场所通风良好。 D1部车间二楼配胶房排风扇未开启。 0
11999 11999 纯办公场所 消防安全 消防通道 疏散通道无占用、堵塞、封闭等现象。安全出口不得上锁。 已整改 1

12000 rows × 7 columns

# train['text']=train['content']+' '+train['level_1']+' '+train['level_2']+' '+train['level_3']+' '+train['level_4']
# test['text']=test['content']+' '+test['level_1']+' '+test['level_2']+' '+test['level_3']+' '+test['level_4']
train['text']=train['content']+'[SEP]'+train['level_1']+'[SEP]'+train['level_2']+'[SEP]'+train['level_3']+'[SEP]'+train['level_4']
test['text']=test['content']+'[SEP]'+test['level_1']+'[SEP]'+test['level_2']+'[SEP]'+test['level_3']+'[SEP]'+test['level_4']
train.head()
id level_1 level_2 level_3 level_4 content label text
0 0 工业/危化品类 电气安全 移动用电产品、电动工具及照明 移动使用的用电产品和I类电动工具的绝缘线,必须采用三芯(单相)或四芯(三相)多股铜芯橡套软线。 使用移动手动电动工具,外接线绝缘皮破损,应停止使用. 0 使用移动手动电动工具,外接线绝缘皮破损,应停止使用.[SEP]工业/危化品类[SEP]电气安...
1 1 工业/危化品类 消防检查 防火巡查 消防设施、器材和消防安全标志是否在位、完整; 一般 1 一般[SEP]工业/危化品类[SEP]消防检查[SEP]防火巡查[SEP]消防设施、器材和消...
2 2 工业/危化品类 消防检查 防火检查 重点工种人员以及其他员工消防知识的掌握情况; 消防知识要加强 0 消防知识要加强[SEP]工业/危化品类[SEP]消防检查[SEP]防火检查[SEP]重点工种...
3 3 工业/危化品类 消防检查 防火巡查 消防设施、器材和消防安全标志是否在位、完整; 消防通道有货物摆放 清理不及时 0 消防通道有货物摆放 清理不及时[SEP]工业/危化品类[SEP]消防检查[SEP]防火巡查[...
4 4 工业/危化品类 消防检查 防火巡查 常闭式防火门是否处于关闭状态,防火卷帘下是否堆放物品影响使用; 防火门打开状态 0 防火门打开状态[SEP]工业/危化品类[SEP]消防检查[SEP]防火巡查[SEP]常闭式防...

3.1 Text length distribution

train['text_len']=train['text'].map(len)
train['text'].map(len).describe() # maximum text length: 298 characters
count    12000.000000
mean        80.439833
std         21.913662
min         43.000000
25%         66.000000
50%         75.000000
75%         92.000000
max        298.000000
Name: text, dtype: float64
test['text'].map(len).describe() # maximum text length: 520 characters
count    18000.000000
mean        80.762611
std         22.719823
min         43.000000
25%         66.000000
50%         76.000000
75%         92.000000
max        520.000000
Name: text, dtype: float64
train['text_len'].plot(kind='kde')
<AxesSubplot:ylabel='Density'>

[Figure output_30_1.png: kernel density plot of text_len]

sum(train['text_len']>100) # number of texts longer than 100 characters
sum(train['text_len']>200) # number of texts longer than 200 characters (only this last value is displayed)
11
1878/len(train) # 1878 is presumably the count of texts longer than 100 characters
0.1565

4 Getting to know the Tokenizer

4.1 Mapping text to token ids

PRE_TRAINED_MODEL_NAME = 'bert-base-chinese'
# PRE_TRAINED_MODEL_NAME = 'hfl/chinese-roberta-wwm-ext'
# PRE_TRAINED_MODEL_NAME = 'hfl/chinese-roberta-wwm'

# tokenizer = BertTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)
# tokenizer = BertTokenizer.from_pretrained('C:\\Users\\yanqiang\\Desktop\\bert-base-chinese')
tokenizer
PreTrainedTokenizerFast(name_or_path='bert-base-chinese', vocab_size=21128, model_max_len=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})
  • The tokenizer's vocabulary size is 21128.
  • Its special tokens are special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}.

Let's try tokenizing a sample sentence with the BERT tokenizer.

sample_txt = '今天早上9点半起床,我在学习预训练模型的使用.'
len(sample_txt)
23
tokens = tokenizer.tokenize(sample_txt)
token_ids = tokenizer.convert_tokens_to_ids(tokens)

print(f'Text: {sample_txt}')
print(f'Tokens: {tokens}')
print(f'Token ids: {token_ids}')
Text: 今天早上9点半起床,我在学习预训练模型的使用.
Tokens: ['今', '天', '早', '上', '9', '点', '半', '起', '床', ',', '我', '在', '学', '习', '预', '训', '练', '模', '型', '的', '使', '用', '.']
Token ids: [791, 1921, 3193, 677, 130, 4157, 1288, 6629, 2414, 8024, 2769, 1762, 2110, 739, 7564, 6378, 5298, 3563, 1798, 4638, 886, 4500, 119]

4.2 Special tokens

tokenizer.sep_token, tokenizer.sep_token_id
('[SEP]', 102)
tokenizer.unk_token, tokenizer.unk_token_id
('[UNK]', 100)
tokenizer.pad_token, tokenizer.pad_token_id
('[PAD]', 0)
tokenizer.cls_token, tokenizer.cls_token_id
('[CLS]', 101)
tokenizer.mask_token, tokenizer.mask_token_id
('[MASK]', 103)

encode_plus() can be used to tokenize a sentence and add the special tokens in one call:

encoding=tokenizer.encode_plus(
    sample_txt,
    # sample_txt_another,
    max_length=32,
    add_special_tokens=True, # add [CLS] and [SEP]
    return_token_type_ids=True,
    pad_to_max_length=True,
    return_attention_mask=True,
    return_tensors='pt', # return PyTorch tensors

)
encoding.keys()
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
F:\ProgramData\Anaconda3\lib\site-packages\transformers\tokenization_utils_base.py:2271: FutureWarning: The `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 512 for Bert).
  warnings.warn(





dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])
encoding
{'input_ids': tensor([[ 101,  791, 1921, 3193,  677,  130, 4157, 1288, 6629, 2414, 8024, 2769,
         1762, 2110,  739, 7564, 6378, 5298, 3563, 1798, 4638,  886, 4500,  119,
          102,    0,    0,    0,    0,    0,    0,    0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 0, 0, 0, 0, 0, 0, 0]])}

The token ids form a tensor of length 32:

print(len(encoding['input_ids'][0]))
32

The attention mask has the same length:

print(len(encoding['attention_mask'][0]))
encoding['attention_mask']
32





tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 0, 0, 0, 0, 0, 0, 0]])

Converting the ids back to tokens shows what each position holds:

tokenizer.convert_ids_to_tokens(encoding['input_ids'][0])
['[CLS]', '今', '天', '早', '上', '9', '点', '半', '起', '床', ',', '我', '在', '学', '习', '预', '训', '练', '模', '型', '的', '使', '用', '.', '[SEP]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']

4.3 Choosing the maximum text length

token_lens = []

for txt in train.text:
    tokens = tokenizer.encode(txt, max_length=512)
    token_lens.append(len(tokens))
sns.distplot(token_lens)
plt.xlim([0, 256]);
plt.xlabel('Token count');
F:\ProgramData\Anaconda3\lib\site-packages\seaborn\distributions.py:2619: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).
  warnings.warn(msg, FutureWarning)

[Figure output_58_1.png: distribution of token counts per sample]

Most texts have fewer than 100 token ids, so we set the maximum length to 160 (a quick quantile check is sketched after the next cell).

MAX_LEN = 160
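A quick sanity check (a minimal sketch reusing the token_lens list built above; the high quantiles should fall well below MAX_LEN = 160):

import numpy as np

# token_lens holds the per-sample token counts computed in the loop above
print(np.quantile(token_lens, [0.5, 0.95, 0.99]))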

5 Building the dataset

5.1 A custom Dataset

class EnterpriseDataset(Dataset):
    def __init__(self,texts,labels,tokenizer,max_len):
        self.texts=texts
        self.labels=labels
        self.tokenizer=tokenizer
        self.max_len=max_len
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self,item):
        """
        item 为数据索引,迭代取第item条数据
        """
        text=str(self.texts[item])
        label=self.labels[item]
        
        encoding=self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=True,
            pad_to_max_length=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        
#         print(encoding['input_ids'])
        return {
            'texts':text,
            'input_ids':encoding['input_ids'].flatten(),
            'attention_mask':encoding['attention_mask'].flatten(),
            # token_type_ids are all 0 for single-sentence inputs, so they are not returned here
            'labels':torch.tensor(label,dtype=torch.long)
        }
        

5.2 Splitting the data and creating data loaders

df_train, df_test = train_test_split(train, test_size=0.1, random_state=RANDOM_SEED)
df_val, df_test = train_test_split(df_test, test_size=0.5, random_state=RANDOM_SEED)
df_train.shape, df_val.shape, df_test.shape
((10800, 9), (600, 9), (600, 9))
def create_data_loader(df,tokenizer,max_len,batch_size):
    ds=EnterpriseDataset(
        texts=df['text'].values,
        labels=df['label'].values,
        tokenizer=tokenizer,
        max_len=max_len
    )
    
    return DataLoader(
        ds,
        batch_size=batch_size,
#         num_workers=4 # multi-worker loading can be problematic on Windows, so it is left disabled
    )
BATCH_SIZE = 4

train_data_loader = create_data_loader(df_train, tokenizer, MAX_LEN, BATCH_SIZE)
val_data_loader = create_data_loader(df_val, tokenizer, MAX_LEN, BATCH_SIZE)
test_data_loader = create_data_loader(df_test, tokenizer, MAX_LEN, BATCH_SIZE)
next(iter(train_data_loader))
F:\ProgramData\Anaconda3\lib\site-packages\transformers\tokenization_utils_base.py:2271: FutureWarning: The `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 512 for Bert).
  warnings.warn(





{'texts': ['指示标识不清楚[SEP]工业/危化品类[SEP]消防检查[SEP]防火巡查[SEP]安全出口、疏散通道是否畅通,安全疏散指示标志、应急照明是否完好;',
  '发现本月有灭火器过期,已安排购买灭火器更换[SEP]商贸服务教文卫类[SEP]消防检查[SEP]防火检查[SEP]灭火器材配置及有效情况。',
  '安全出口标志灯有一个有故障,已买回安装改正。[SEP]工业/危化品类[SEP]消防检查[SEP]防火巡查[SEP]安全出口、疏散通道是否畅通,安全疏散指示标志、应急照明是否完好;',
  '堵了消防通道[SEP]工业/危化品类[SEP]消防检查[SEP]防火巡查[SEP]安全出口、疏散通道是否畅通,安全疏散指示标志、应急照明是否完好;'],
 'input_ids': tensor([[ 101, 2900, 4850, 3403, 6399,  679, 3926, 3504,  102, 2339,  689,  120,
          1314, 1265, 1501, 5102,  102, 3867, 7344, 3466, 3389,  102, 7344, 4125,
          2337, 3389,  102, 2128, 1059, 1139, 1366,  510, 4541, 3141, 6858, 6887,
          3221, 1415, 4517, 6858, 8024, 2128, 1059, 4541, 3141, 2900, 4850, 3403,
          2562,  510, 2418, 2593, 4212, 3209, 3221, 1415, 2130, 1962, 8039,  102,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0],
         [ 101, 1355, 4385, 3315, 3299, 3300, 4127, 4125, 1690, 6814, 3309, 8024,
          2347, 2128, 2961, 6579,  743, 4127, 4125, 1690, 3291, 2940,  102, 1555,
          6588, 3302, 1218, 3136, 3152, 1310, 5102,  102, 3867, 7344, 3466, 3389,
           102, 7344, 4125, 3466, 3389,  102, 4127, 4125, 1690, 3332, 6981, 5390,
          1350, 3300, 3126, 2658, 1105,  511,  102,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0],
         [ 101, 2128, 1059, 1139, 1366, 3403, 2562, 4128, 3300,  671,  702, 3300,
          3125, 7397, 8024, 2347,  743, 1726, 2128, 6163, 3121, 3633,  511,  102,
          2339,  689,  120, 1314, 1265, 1501, 5102,  102, 3867, 7344, 3466, 3389,
           102, 7344, 4125, 2337, 3389,  102, 2128, 1059, 1139, 1366,  510, 4541,
          3141, 6858, 6887, 3221, 1415, 4517, 6858, 8024, 2128, 1059, 4541, 3141,
          2900, 4850, 3403, 2562,  510, 2418, 2593, 4212, 3209, 3221, 1415, 2130,
          1962, 8039,  102,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0],
         [ 101, 1843,  749, 3867, 7344, 6858, 6887,  102, 2339,  689,  120, 1314,
          1265, 1501, 5102,  102, 3867, 7344, 3466, 3389,  102, 7344, 4125, 2337,
          3389,  102, 2128, 1059, 1139, 1366,  510, 4541, 3141, 6858, 6887, 3221,
          1415, 4517, 6858, 8024, 2128, 1059, 4541, 3141, 2900, 4850, 3403, 2562,
           510, 2418, 2593, 4212, 3209, 3221, 1415, 2130, 1962, 8039,  102,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0]]),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]),
 'labels': tensor([0, 0, 0, 0])}
data = next(iter(train_data_loader))
data.keys()
dict_keys(['texts', 'input_ids', 'attention_mask', 'labels'])
print(data['input_ids'].shape)
print(data['attention_mask'].shape)
print(data['labels'].shape)
torch.Size([4, 160])
torch.Size([4, 160])
torch.Size([4])

6 Building the hazard-report classifier with Hugging Face

# bert_model = BertModel.from_pretrained(PRE_TRAINED_MODEL_NAME)
bert_model = AutoModel.from_pretrained(PRE_TRAINED_MODEL_NAME)
Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
bert_model
BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(21128, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (2): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (3): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (4): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (5): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (6): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (7): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (8): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (9): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (10): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (11): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
  )
  (pooler): BertPooler(
    (dense): Linear(in_features=768, out_features=768, bias=True)
    (activation): Tanh()
  )
)
encoding
{'input_ids': tensor([[ 101,  791, 1921, 3193,  677,  130, 4157, 1288, 6629, 2414, 8024, 2769,
         1762, 2110,  739, 7564, 6378, 5298, 3563, 1798, 4638,  886, 4500,  119,
          102,    0,    0,    0,    0,    0,    0,    0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 0, 0, 0, 0, 0, 0, 0]])}
last_hidden_state, pooled_output = bert_model(
    input_ids=encoding['input_ids'], 
    attention_mask=encoding['attention_mask'],
    return_dict = False
)
last_hidden_state # the vector representation of every token
tensor([[[ 0.8880,  0.1987,  1.3610,  ..., -0.5096,  0.3742, -0.2368],
         [-0.0747,  0.3148,  1.4699,  ..., -1.0238, -0.0518, -0.0557],
         [ 1.0133, -0.6058,  1.0152,  ...,  0.3536,  1.1091, -0.1179],
         ...,
         [ 0.4613,  0.4155, -0.4329,  ...,  0.1605, -0.3617, -0.2294],
         [ 0.4403,  0.4763, -0.5568,  ...,  0.2216, -0.3297, -0.3064],
         [ 0.4437,  0.3844, -0.4880,  ...,  0.0670, -0.5105, -0.2472]]],
       grad_fn=<NativeLayerNormBackward>)
pooled_output
tensor([[ 0.9999,  0.9998,  0.9989,  0.9629,  0.3075, -0.1866, -0.9904,  0.8628,
          0.9710, -0.9993,  1.0000,  1.0000,  0.9312, -0.9394,  0.9998, -0.9999,
          0.0417,  0.9999,  0.9458,  0.3190,  1.0000, -1.0000, -0.9062, -0.9048,
          0.1764,  0.9983,  0.9346, -0.8122, -0.9999,  0.9996,  0.7879,  0.9999,
          0.8475, -1.0000, -1.0000,  0.9413, -0.8260,  0.9889, -0.4976, -0.9857,
         -0.9955, -0.9580,  0.5833, -0.9996, -0.8932,  0.8563, -1.0000, -0.9999,
          0.9719,  0.9999, -0.7430, -0.9993,  0.9756, -0.9754,  0.2991,  0.8933,
         -0.9991,  0.9987,  1.0000,  0.4156,  0.9992, -0.9452, -0.8020, -0.9999,
          1.0000, -0.9964, -0.9900,  0.4365,  1.0000,  1.0000, -0.9400,  0.8794,
          1.0000,  0.9105, -0.6616,  1.0000, -0.9999,  0.6892, -1.0000, -0.9817,
          1.0000,  0.9957, -0.8844, -0.8248, -0.9921, -0.9999, -0.9998,  1.0000,
          0.5228,  0.1297,  0.9932, -0.9999, -1.0000,  0.9993, -0.9996, -0.9948,
         -0.9561,  0.9996, -0.5785, -0.9386, -0.2035,  0.9086, -0.9999, -0.9993,
          0.9959,  0.9984,  0.6953, -0.9995,  1.0000,  0.8610, -1.0000, -0.4507,
         -1.0000,  0.2384, -0.9812,  0.9998,  0.9504,  0.5421,  0.9995, -0.9998,
          0.9320, -0.9941, -0.9718, -0.9910,  0.9822,  1.0000,  0.9997, -0.9990,
          1.0000,  1.0000,  0.8608,  0.9964, -0.9997,  0.9799,  0.5985, -0.9098,
          0.5329, -0.6345,  1.0000,  0.9872,  0.9970, -0.9719,  0.9988, -0.9933,
          1.0000, -0.9999,  0.9973, -1.0000, -0.6550,  0.9996,  0.8899,  1.0000,
          0.2969,  0.9999, -0.9983, -0.9991,  0.9906, -0.6590,  0.9872, -1.0000,
          0.7658,  0.7876, -0.8556,  0.6304, -1.0000,  1.0000, -0.7938,  1.0000,
          0.9898,  0.2216, -0.9942, -0.9969,  0.8345, -0.9998, -0.9779,  0.9914,
          0.5227,  0.9992, -0.9893, -0.9889,  0.2325, -0.9887, -0.9999,  0.9885,
          0.0340,  0.9284,  0.5197,  0.4143,  0.8315,  0.1585, -0.5348,  1.0000,
          0.2361,  0.9985,  0.9999, -0.3446,  0.1012, -0.9924, -1.0000, -0.7542,
          0.9999, -0.2807, -0.9999,  0.9490, -1.0000,  0.9906, -0.7288, -0.5263,
         -0.9545, -0.9999,  0.9998, -0.9286, -0.9997, -0.5303,  0.8886,  0.5605,
         -0.9989, -0.3324,  0.9804, -0.9075,  0.9905, -0.9800, -0.9946,  0.6856,
         -0.9393,  0.9929,  0.9874,  1.0000,  0.9997, -0.0714, -0.9440,  1.0000,
          0.1676, -1.0000,  0.5573, -0.9611,  0.8835,  0.9999, -0.9980,  0.9294,
          1.0000,  0.7968,  1.0000, -0.7065, -0.9793, -0.9997,  1.0000,  0.9922,
          0.9999, -0.9984, -0.9995, -0.1701, -0.5426, -1.0000, -1.0000, -0.6334,
          0.9969,  0.9999, -0.1620, -0.9818, -0.9921, -0.9994,  1.0000, -0.9759,
          1.0000,  0.8570, -0.7434, -0.9164,  0.9438, -0.7311, -0.9986, -0.3936,
         -0.9997, -0.9650, -1.0000,  0.9433, -0.9999, -1.0000,  0.6913,  1.0000,
          0.8762, -1.0000,  0.9997,  0.9764,  0.7094, -0.9294,  0.9522, -1.0000,
          1.0000, -0.9965,  0.9428, -0.9972, -0.9897, -0.7680,  0.9922,  0.9999,
         -0.9999, -0.9597, -0.9922, -0.9807, -0.3632,  0.9936, -0.7280,  0.4117,
         -0.9498, -0.9666,  0.9545, -0.9957, -0.9970,  0.4028,  1.0000, -0.9798,
          1.0000,  0.9941,  1.0000,  0.9202, -0.9942,  0.9996,  0.5352, -0.5836,
         -0.8829, -0.9418,  0.9497, -0.0532,  0.6966, -0.9999,  0.9998,  0.9917,
          0.9612,  0.7289,  0.0167,  0.3179,  0.9627, -0.9911,  0.9995, -0.9996,
         -0.6737,  0.9991,  1.0000,  0.9932,  0.4880, -0.7488,  0.9986, -0.9961,
          0.9995, -1.0000,  0.9999, -0.9940,  0.9705, -0.9970, -0.9856,  1.0000,
          0.9846, -0.7932,  0.9997, -0.9386,  0.9938,  0.9738,  0.8173,  0.9913,
          0.9981,  1.0000, -0.9998, -0.9918, -0.9727, -0.9987, -0.9955, -1.0000,
         -0.1038, -1.0000, -0.9874, -0.9287,  0.5109, -0.9056,  0.1022,  0.7864,
         -0.8197,  0.5724, -0.5905,  0.2713, -0.7239, -0.9976, -0.9844, -1.0000,
         -0.9988,  0.8835,  0.9999, -0.9997,  0.9999, -0.9999, -0.9782,  0.9383,
         -0.5609,  0.7721,  0.9999, -1.0000,  0.9585,  0.9987,  1.0000,  0.9960,
          0.9993, -0.9741, -0.9999, -0.9989, -0.9999, -1.0000, -0.9998,  0.9343,
          0.6337, -1.0000,  0.0902,  0.8980,  1.0000,  0.9964, -0.9985, -0.6136,
         -0.9996, -0.8252,  0.9996, -0.0566, -1.0000,  0.9962, -0.8744,  1.0000,
         -0.8865,  0.9879,  0.8897,  0.9571,  0.9823, -1.0000,  0.9145,  1.0000,
          0.0365, -1.0000, -0.9985, -0.9075, -0.9998,  0.0369,  0.8120,  0.9999,
         -1.0000, -0.9155, -0.9975,  0.7988,  0.9922,  0.9998,  0.9982,  0.9267,
          0.9165,  0.5368,  0.1464,  0.9998,  0.4663, -0.9989,  0.9996, -0.7952,
          0.4527, -1.0000,  0.9998,  0.4073,  0.9999,  0.9159, -0.5480, -0.6822,
         -0.9904,  0.9938,  1.0000, -0.4229, -0.4845, -0.9981, -1.0000, -0.9861,
         -0.0950, -0.4625, -0.9629, -0.9998,  0.6675, -0.5244,  1.0000,  1.0000,
          0.9924, -0.9253, -0.9974,  0.9974, -0.9012,  0.9900, -0.2582, -1.0000,
         -0.9919, -0.9986,  1.0000, -0.9716, -0.9262, -0.9911, -0.2593,  0.5919,
         -0.9999, -0.4994, -0.9962,  0.9818,  1.0000, -0.9996,  0.9918, -0.9970,
          0.7085, -0.1369,  0.8077,  0.9955, -0.3394, -0.5860, -0.6887, -0.9841,
          0.9970,  0.9987, -0.9948, -0.8401,  0.9999,  0.0856,  0.9999,  0.5099,
          0.9466,  0.9567,  1.0000,  0.8771,  1.0000, -0.0815,  1.0000,  0.9999,
         -0.9392,  0.5744,  0.8723, -0.9686,  0.5958,  0.9822,  0.9997,  0.8854,
         -0.1952, -0.9967,  0.9994,  1.0000,  1.0000, -0.3391,  0.9883, -0.4452,
          0.9252,  0.4495,  0.9870,  0.3479,  0.2266,  0.9942,  0.9990, -0.9999,
         -0.9999, -1.0000,  1.0000,  0.9996, -0.6637, -1.0000,  0.9999,  0.4543,
          0.7471,  0.9983,  0.3772, -0.9812,  0.9853, -0.9995, -0.3404,  0.9788,
          0.9867,  0.7564,  0.9995, -0.9997,  0.7990,  1.0000,  0.0752,  0.9999,
          0.2912, -0.9941,  0.9970, -0.9935, -0.9995, -0.9743,  0.9991,  0.9981,
         -0.9273, -0.8402,  0.9996, -0.9999,  0.9999, -0.9998,  0.9724, -0.9939,
          1.0000, -0.9752, -0.9998, -0.3806,  0.8830,  0.8352, -0.8892,  1.0000,
         -0.8875, -0.8107,  0.7083, -0.8909, -0.9931, -0.9630,  0.0800, -1.0000,
          0.7777, -0.9611,  0.5867, -0.9947, -0.9999,  1.0000, -0.9084, -0.9414,
          0.9999, -0.8838, -1.0000,  0.9549, -0.9999, -0.6522,  0.7967, -0.6850,
          0.1524, -1.0000,  0.4800,  0.9999, -0.9998, -0.7089, -0.9129, -0.9864,
          0.6220,  0.8855,  0.9855, -0.8651,  0.3988, -0.2548,  0.9793, -0.7212,
         -0.2582, -0.9999, -0.8692, -0.6282, -0.9999, -0.9999, -1.0000,  1.0000,
          0.9996,  0.9999, -0.5600,  0.7442,  0.9460,  0.9927, -0.9999,  0.4407,
         -0.0461,  0.9937, -0.4887, -0.9994, -0.9198, -1.0000, -0.6905,  0.3538,
         -0.7728,  0.6622,  1.0000,  0.9999, -0.9999, -0.9994, -0.9995, -0.9979,
          0.9998,  0.9999,  0.9996, -0.9072, -0.5844,  0.9997,  0.9689,  0.5231,
         -0.9999, -0.9981, -0.9999,  0.7505, -0.9922, -0.9986,  0.9971,  1.0000,
          0.8730, -1.0000, -0.9533,  1.0000,  0.9997,  1.0000, -0.7768,  0.9999,
         -0.9838,  0.9819, -0.9993,  1.0000, -1.0000,  1.0000,  0.9999,  0.9809,
          0.9984, -0.9928,  0.9776, -0.9998, -0.7407,  0.9298, -0.4495, -0.9902,
          0.8053,  0.9996, -0.9952,  1.0000,  0.9243, -0.2028,  0.8002,  0.9873,
          0.9419, -0.6913, -0.9999,  0.8162,  0.9995,  0.9509,  1.0000,  0.9177,
          0.9996, -0.9839, -0.9998,  0.9914, -0.6991, -0.7821, -0.9998,  1.0000,
          1.0000, -0.9999, -0.9227,  0.7483,  0.1186,  1.0000,  0.9963,  0.9971,
          0.9857,  0.3887,  0.9996, -0.9999,  0.8526, -0.9980, -0.8613,  0.9999,
         -0.9899,  0.9999, -0.9981,  1.0000, -0.9858,  0.9944,  0.9989,  0.9684,
         -0.9968,  1.0000,  0.8246, -0.9956, -0.8348, -0.9374, -0.9999,  0.7827]],
       grad_fn=<TanhBackward>)
last_hidden_state.shape # shape of the per-token representations
torch.Size([1, 32, 768])
pooled_output.shape # the pooled [CLS] vector
torch.Size([1, 768])
bert_model.config.hidden_size
768
pooled_output.shape
# the pooled output serves as the whole-sentence representation
torch.Size([1, 768])
class EnterpriseDangerClassifier(nn.Module):
    def __init__(self, n_classes):
        super(EnterpriseDangerClassifier, self).__init__()
        # BertModel itself was never imported above, so load the backbone with AutoModel
        self.bert = AutoModel.from_pretrained(PRE_TRAINED_MODEL_NAME)
        self.drop = nn.Dropout(p=0.3)
        self.out = nn.Linear(self.bert.config.hidden_size, n_classes) # two classes
    def forward(self, input_ids, attention_mask):
        _, pooled_output = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict = False
        )
        output = self.drop(pooled_output) # dropout on the pooled [CLS] representation
        return self.out(output)
class_names=[0,1]
model = EnterpriseDangerClassifier(len(class_names))
model = model.to(device)
Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
data
{'texts': ['指示标识不清楚[SEP]工业/危化品类[SEP]消防检查[SEP]防火巡查[SEP]安全出口、疏散通道是否畅通,安全疏散指示标志、应急照明是否完好;',
  '发现本月有灭火器过期,已安排购买灭火器更换[SEP]商贸服务教文卫类[SEP]消防检查[SEP]防火检查[SEP]灭火器材配置及有效情况。',
  '安全出口标志灯有一个有故障,已买回安装改正。[SEP]工业/危化品类[SEP]消防检查[SEP]防火巡查[SEP]安全出口、疏散通道是否畅通,安全疏散指示标志、应急照明是否完好;',
  '堵了消防通道[SEP]工业/危化品类[SEP]消防检查[SEP]防火巡查[SEP]安全出口、疏散通道是否畅通,安全疏散指示标志、应急照明是否完好;'],
 'input_ids': tensor([[ 101, 2900, 4850, 3403, 6399,  679, 3926, 3504,  102, 2339,  689,  120,
          1314, 1265, 1501, 5102,  102, 3867, 7344, 3466, 3389,  102, 7344, 4125,
          2337, 3389,  102, 2128, 1059, 1139, 1366,  510, 4541, 3141, 6858, 6887,
          3221, 1415, 4517, 6858, 8024, 2128, 1059, 4541, 3141, 2900, 4850, 3403,
          2562,  510, 2418, 2593, 4212, 3209, 3221, 1415, 2130, 1962, 8039,  102,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0],
         [ 101, 1355, 4385, 3315, 3299, 3300, 4127, 4125, 1690, 6814, 3309, 8024,
          2347, 2128, 2961, 6579,  743, 4127, 4125, 1690, 3291, 2940,  102, 1555,
          6588, 3302, 1218, 3136, 3152, 1310, 5102,  102, 3867, 7344, 3466, 3389,
           102, 7344, 4125, 3466, 3389,  102, 4127, 4125, 1690, 3332, 6981, 5390,
          1350, 3300, 3126, 2658, 1105,  511,  102,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0],
         [ 101, 2128, 1059, 1139, 1366, 3403, 2562, 4128, 3300,  671,  702, 3300,
          3125, 7397, 8024, 2347,  743, 1726, 2128, 6163, 3121, 3633,  511,  102,
          2339,  689,  120, 1314, 1265, 1501, 5102,  102, 3867, 7344, 3466, 3389,
           102, 7344, 4125, 2337, 3389,  102, 2128, 1059, 1139, 1366,  510, 4541,
          3141, 6858, 6887, 3221, 1415, 4517, 6858, 8024, 2128, 1059, 4541, 3141,
          2900, 4850, 3403, 2562,  510, 2418, 2593, 4212, 3209, 3221, 1415, 2130,
          1962, 8039,  102,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0],
         [ 101, 1843,  749, 3867, 7344, 6858, 6887,  102, 2339,  689,  120, 1314,
          1265, 1501, 5102,  102, 3867, 7344, 3466, 3389,  102, 7344, 4125, 2337,
          3389,  102, 2128, 1059, 1139, 1366,  510, 4541, 3141, 6858, 6887, 3221,
          1415, 4517, 6858, 8024, 2128, 1059, 4541, 3141, 2900, 4850, 3403, 2562,
           510, 2418, 2593, 4212, 3209, 3221, 1415, 2130, 1962, 8039,  102,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0]]),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]),
 'labels': tensor([0, 0, 0, 0])}
input_ids = data['input_ids'].to(device)
attention_mask = data['attention_mask'].to(device)

print(input_ids.shape) # batch size x seq length
print(attention_mask.shape) # batch size x seq length
torch.Size([4, 160])
torch.Size([4, 160])
model(input_ids, attention_mask)
tensor([[-0.3011, -0.3009],
        [ 0.2871,  0.1841],
        [ 0.2703, -0.0926],
        [-0.3193, -0.1487]], device='cuda:0', grad_fn=<AddmmBackward>)
F.softmax(model(input_ids, attention_mask), dim=1)
tensor([[0.6495, 0.3505],
        [0.6752, 0.3248],
        [0.7261, 0.2739],
        [0.4528, 0.5472]], device='cuda:0', grad_fn=<SoftmaxBackward>)

7 Model training

EPOCHS = 10 # number of training epochs

optimizer = AdamW(model.parameters(), lr=2e-5, correct_bias=False)
total_steps = len(train_data_loader) * EPOCHS

scheduler = get_linear_schedule_with_warmup(
  optimizer,
  num_warmup_steps=0,
  num_training_steps=total_steps
)

loss_fn = nn.CrossEntropyLoss().to(device)
F:\ProgramData\Anaconda3\lib\site-packages\transformers\optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
  warnings.warn(
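The FutureWarning above suggests switching to the PyTorch implementation of AdamW; a drop-in alternative for that cell would look like this (a sketch; the scheduler setup stays the same, and torch.optim.AdamW has no correct_bias argument because bias correction is always applied):

from torch.optim import AdamW as TorchAdamW

# PyTorch's own AdamW, which silences the transformers deprecation warning
optimizer = TorchAdamW(model.parameters(), lr=2e-5)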
def train_epoch(
  model, 
  data_loader, 
  loss_fn, 
  optimizer, 
  device, 
  scheduler, 
  n_examples
):
    model = model.train() # switch to training mode
    losses = []
    correct_predictions = 0
    for d in data_loader:
        input_ids = d["input_ids"].to(device)
        attention_mask = d["attention_mask"].to(device)
        targets = d["labels"].to(device)
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        _, preds = torch.max(outputs, dim=1)
        loss = loss_fn(outputs, targets)
        correct_predictions += torch.sum(preds == targets)
        losses.append(loss.item())
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
    return correct_predictions.double() / n_examples, np.mean(losses)
复制代码
def eval_model(model, data_loader, loss_fn, device, n_examples):
    model = model.eval() # switch to evaluation mode

    losses = []
    correct_predictions = 0

    with torch.no_grad():
        for d in data_loader:
            input_ids = d["input_ids"].to(device)
            attention_mask = d["attention_mask"].to(device)
            targets = d["labels"].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
            _, preds = torch.max(outputs, dim=1)

            loss = loss_fn(outputs, targets)

            correct_predictions += torch.sum(preds == targets)
            losses.append(loss.item())

    return correct_predictions.double() / n_examples, np.mean(losses)

复制代码

history = defaultdict(list) # record loss and accuracy for every epoch
best_accuracy = 0

for epoch in range(EPOCHS):

    print(f'Epoch {epoch + 1}/{EPOCHS}')
    print('-' * 10)

    train_acc, train_loss = train_epoch(
        model,
        train_data_loader,
        loss_fn,
        optimizer,
        device,
        scheduler,
        len(df_train)
    )

    print(f'Train loss {train_loss} accuracy {train_acc}')

    val_acc, val_loss = eval_model(
        model,
        val_data_loader,
        loss_fn,
        device,
        len(df_val)
    )

    print(f'Val   loss {val_loss} accuracy {val_acc}')
    print()

    history['train_acc'].append(train_acc)
    history['train_loss'].append(train_loss)
    history['val_acc'].append(val_acc)
    history['val_loss'].append(val_loss)

    if val_acc > best_accuracy:
        torch.save(model.state_dict(), 'best_model_state.bin')
        best_accuracy = val_acc

复制代码
Epoch 1/10
----------
Train loss 0.49691140200114914 accuracy 0.8901851851851852
Val   loss 0.40999091763049367 accuracy 0.9

Epoch 2/10
----------
Train loss 0.3062430267758383 accuracy 0.9349999999999999
Val   loss 0.20030112245275328 accuracy 0.9650000000000001

Epoch 3/10
----------
Train loss 0.18264216477097728 accuracy 0.9603703703703703
Val   loss 0.18755523634143173 accuracy 0.9650000000000001

Epoch 4/10
----------
Train loss 0.15700688022613543 accuracy 0.9693518518518518
Val   loss 0.20371213133369262 accuracy 0.9633333333333334

Epoch 5/10
----------
Train loss 0.1627817107436756 accuracy 0.9668518518518519
Val   loss 0.16456402061972766 accuracy 0.9683333333333334

Epoch 6/10
----------
Train loss 0.15311389193888453 accuracy 0.9721296296296296
Val   loss 0.1188539441426595 accuracy 0.9783333333333334

Epoch 7/10
----------
Train loss 0.13947947008179012 accuracy 0.9734259259259259
Val   loss 0.12033098526764661 accuracy 0.9783333333333334

Epoch 8/10
----------
Train loss 0.12078767392419482 accuracy 0.9781481481481481
Val   loss 0.12014915000802527 accuracy 0.9733333333333334

Epoch 9/10
----------
Train loss 0.11557375699952967 accuracy 0.9751851851851852
Val   loss 0.12187736847476724 accuracy 0.9766666666666667

Epoch 10/10
----------
Train loss 0.10247013699765645 accuracy 0.977037037037037
Val   loss 0.11501088156461871 accuracy 0.9766666666666667

复制代码
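The loop above checkpoints on validation accuracy, but the leaderboard metric is F1 and the two classes are imbalanced, so accuracy can be a blunt selection criterion. A hedged variant that tracks validation F1 instead (a sketch reusing the loader and device names above; eval_f1 and best_f1 are illustrative names, not part of the original code):

from sklearn.metrics import f1_score

def eval_f1(model, data_loader, device):
    # Collect hard predictions and true labels over the whole loader.
    model = model.eval()
    all_preds, all_targets = [], []
    with torch.no_grad():
        for d in data_loader:
            input_ids = d["input_ids"].to(device)
            attention_mask = d["attention_mask"].to(device)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            all_preds.extend(torch.argmax(outputs, dim=1).cpu().tolist())
            all_targets.extend(d["labels"].tolist())
    return f1_score(all_targets, all_preds)

# Inside the epoch loop, checkpoint on F1 instead of accuracy:
# val_f1 = eval_f1(model, val_data_loader, device)
# if val_f1 > best_f1:
#     torch.save(model.state_dict(), 'best_model_state.bin')
#     best_f1 = val_f1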
plt.plot(history['train_acc'], label='train accuracy')
plt.plot(history['val_acc'], label='validation accuracy')

plt.title('Training history')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend()
plt.ylim([0, 1]);
复制代码

(Figure: training vs. validation accuracy curves over the 10 epochs)
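One caveat: train_acc and val_acc come back as 0-dim tensors living on the GPU, and some matplotlib/NumPy versions refuse to plot CUDA tensors directly. Converting the recorded values to plain floats first is a safe variant (a sketch using the same history dict):

train_acc_values = [float(a) for a in history['train_acc']]  # 0-dim CUDA tensors -> Python floats
val_acc_values = [float(a) for a in history['val_acc']]

plt.plot(train_acc_values, label='train accuracy')
plt.plot(val_acc_values, label='validation accuracy')
plt.legend()
plt.ylim([0, 1]);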

# model = EnterpriseDangerClassifier(len(class_names))
# model.load_state_dict(torch.load('best_model_state.bin'))
# model = model.to(device)
复制代码
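The commented-out lines show how to reload the checkpoint saved during training; evaluation should normally start from that best state rather than from whatever the last epoch left in memory. A minimal sketch (EnterpriseDangerClassifier and class_names are the names used earlier in this notebook):

best_model = EnterpriseDangerClassifier(len(class_names))
best_model.load_state_dict(torch.load('best_model_state.bin'))
best_model = best_model.to(device)
best_model.eval()  # switch to evaluation mode before inference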

8 Model Evaluation

test_acc, _ = eval_model(
  model,
  test_data_loader,
  loss_fn,
  device,
  len(df_test)
)

test_acc.item()
复制代码
0.9716666666666667
复制代码
def get_predictions(model, data_loader):
    model = model.eval()

    raw_texts = []
    predictions = []
    prediction_probs = []
    real_values = []

    with torch.no_grad():
        for d in data_loader:
            texts = d["texts"]
            input_ids = d["input_ids"].to(device)
            attention_mask = d["attention_mask"].to(device)
            targets = d["labels"].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
            _, preds = torch.max(outputs, dim=1) # predicted class

            probs = F.softmax(outputs, dim=1) # class probabilities

            raw_texts.extend(texts)
            predictions.extend(preds)
            prediction_probs.extend(probs)
            real_values.extend(targets)

    predictions = torch.stack(predictions).cpu()
    prediction_probs = torch.stack(prediction_probs).cpu()
    real_values = torch.stack(real_values).cpu()
    return raw_texts, predictions, prediction_probs, real_values

复制代码
y_texts, y_pred, y_pred_probs, y_test = get_predictions(
  model,
  test_data_loader
)
复制代码
print(classification_report(y_test, y_pred, target_names=[str(label) for label in class_names])) # classification report
复制代码
              precision    recall  f1-score   support

           0       0.99      0.98      0.98       554
           1       0.81      0.83      0.82        46

    accuracy                           0.97       600
   macro avg       0.90      0.90      0.90       600
weighted avg       0.97      0.97      0.97       600

复制代码
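The report above gives per-class precision and recall, but the competition is scored on F1, so it is worth printing that metric directly as well (a sketch reusing y_test and y_pred from get_predictions):

from sklearn.metrics import f1_score

print('F1 (label=1):', f1_score(y_test, y_pred))                   # positive class = unqualified report
print('Macro F1    :', f1_score(y_test, y_pred, average='macro'))  # useful with imbalanced classes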
def show_confusion_matrix(confusion_matrix):
    hmap = sns.heatmap(confusion_matrix, annot=True, fmt="d", cmap="Blues")
    hmap.yaxis.set_ticklabels(hmap.yaxis.get_ticklabels(), rotation=0, ha='right')
    hmap.xaxis.set_ticklabels(hmap.xaxis.get_ticklabels(), rotation=30, ha='right')
    plt.ylabel('True label')
    plt.xlabel('Predicted label');


cm = confusion_matrix(y_test, y_pred)
df_cm = pd.DataFrame(cm, index=class_names, columns=class_names)
show_confusion_matrix(df_cm)

复制代码

(Figure: confusion matrix heatmap of true vs. predicted labels)

idx = 2

sample_text = y_texts[idx]
true_label = y_test[idx]
pred_df = pd.DataFrame({
  'class_names': class_names,
  'values': y_pred_probs[idx]
})
复制代码
print("\n".join(wrap(sample_text)))
print()
print(f'True label: {class_names[true_label]}')
复制代码
焊锡员工未佩戴防护口罩 工业/危化品类 主要负责人、分管负责人及管理人员履职情况 分管负责人履职情况
分管负责人依法履行安全管理职责(存在职业健康危害的单位需自查职业卫生履职情况)。

True label: 0
复制代码
sns.barplot(x='values', y='class_names', data=pred_df, orient='h')
plt.ylabel('class')
plt.xlabel('probability')
plt.xlim([0, 1]);
复制代码

(Figure: predicted probability for each class on the sample)

9 Test Set Prediction

First, run the model on a single hand-written example as a sanity check.

sample_text = "电源插头应按规定正确接线"
复制代码
encoded_text = tokenizer.encode_plus(
  sample_text,
  max_length=MAX_LEN,
  add_special_tokens=True,
  return_token_type_ids=False,
  padding='max_length',  # pad_to_max_length=True is deprecated in recent transformers
  truncation=True,
  return_attention_mask=True,
  return_tensors='pt',
)
复制代码
input_ids = encoded_text['input_ids'].to(device)
attention_mask = encoded_text['attention_mask'].to(device)

output = model(input_ids, attention_mask)
_, prediction = torch.max(output, dim=1)

print(f'Sample text: {sample_text}')
print(f'Danger label  : {class_names[prediction]}')
复制代码
Sample text: 电源插头应按规定正确接线
Danger label  : 1
复制代码
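Scoring one sentence is useful for a sanity check, but the submission file needs a label for every row of the test set. A rough sketch of batch prediction plus writing results.csv (predict_data_loader is an assumed DataLoader built over the unlabeled test DataFrame in the same way as test_data_loader; it is not defined in the original code):

model = model.eval()
all_preds = []
with torch.no_grad():
    for d in predict_data_loader:  # assumed loader over the `test` DataFrame (no labels)
        input_ids = d["input_ids"].to(device)
        attention_mask = d["attention_mask"].to(device)
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        all_preds.extend(torch.argmax(outputs, dim=1).cpu().tolist())

submission = pd.DataFrame({'id': test['id'], 'label': all_preds})
submission.to_csv('results.csv', index=False, encoding='utf-8')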

Reposted from juejin.im/post/7132516661847916581