fasttext训练模型代码

数据下载
链接: https://pan.baidu.com/s/13g8qi09NXafjJVZXWR2nVQ 密码: rsgl

#!/usr/bin/env python
# -*- coding: utf-8 -*-

# author ChenYongSheng
# date 20201222

import pandas as pd
import jieba

# Data preprocessing: load the labelled corpus and the stop-word list.
df = pd.read_csv('data/8qi/xx.csv', header=0)
# Read the stop words inside a context manager so the file handle is closed
# as soon as the list has been built.
with open('data/all/stopwords.txt', encoding='utf-8') as stopwords_file:
    stopwords = [line.strip() for line in stopwords_file]


def remove_stopwords(text_cut, stopwords):
    """Return the tokens of ``text_cut`` that are not stop words.

    Args:
        text_cut: Iterable of tokens (e.g. the list produced by ``jieba.lcut``).
        stopwords: Collection of tokens to drop.

    Returns:
        A new list holding the surviving tokens in their original order.
    """
    # A membership test against a list scans the whole list every time.
    # Hash the stop words once so each lookup below is O(1); reuse the
    # collection as-is when the caller already passed a set.
    if isinstance(stopwords, (set, frozenset)):
        stop_set = stopwords
    else:
        stop_set = set(stopwords)
    return [word for word in text_cut if word not in stop_set]


# Turn every CSV row into one fastText line of the form
#   "__label__<label> , <tok1> <tok2> ..."
# and split the rows into a training set and a test set: rows whose index is
# a multiple of 10 go to the test set, all others to the training set.
lines = []
test_lines = []
for data in df.itertuples():
    label = '__label__' + str(data.label)
    text = str(data.text)
    text_cut = jieba.lcut(text)
    text_remove_stop = remove_stopwords(text_cut, stopwords)
    # Join the tokens in sentence order, so that the word bigrams learned in
    # training (word_ngrams=2) match those of text that is segmented and
    # joined in natural order at prediction time. Whitespace-only tokens are
    # skipped: they add nothing to a space-separated line, and a newline
    # token would split one sample across two lines of the output file.
    words = ' '.join(word for word in text_remove_stop if word.strip())
    body = label + ' , ' + words
    if data.Index % 10 == 0:
        test_lines.append(body)
    else:
        lines.append(body)

# The context manager closes each file on exit; no explicit close() needed.
with open('data/8qi/train.txt', 'w', encoding='utf-8') as f:
    f.writelines(line + '\n' for line in lines)

with open('data/8qi/test.txt', 'w', encoding='utf-8') as f:
    f.writelines(line + '\n' for line in test_lines)
#!/usr/bin/env python
# -*- coding: utf-8 -*-

# author ChenYongSheng
# date 20201222

import fasttext

import os

# Model training: fit a supervised fastText classifier on the training file
# and persist it to disk.
trainDataFile = 'data/8qi/train.txt'
modelFile = 'model/fasttext_model.bin'

# Make sure the output directory exists before the training run starts:
# saving into a missing directory fails, and the trained model would be lost.
os.makedirs(os.path.dirname(modelFile), exist_ok=True)

model = fasttext.train_supervised(trainDataFile, lr=0.1, dim=100, epoch=30, word_ngrams=2, loss='softmax')
model.save_model(modelFile)


# Evaluation: reload the saved model and score it on the held-out test file.
testDataFile = 'data/8qi/test.txt'

model = fasttext.load_model(modelFile)

# Per the fastText Python API, test() returns a tuple of
# (number of samples, precision, recall).
result = model.test(testDataFile)
print('测试集上数据量', result[0])
print('测试集上准确率', result[1])
print('测试集上召回率', result[2])

必须是这样的数据格式:__label__分类名(空格)(逗号)(空格)(切词)
__label__安静程度 , 吵不吵 房子 那套 肯德基
__label__安静程度 , 吵
__label__安静程度 , 位置 吵 卧室

如果报错 ValueError: data/7期/train.txt cannot be opened for training!
说明数据文件路径中包含中文字符,请将路径改为英文或拼音。

# Classification metrics.
# Builds a precision / recall / f1-score / support report table.
def eval_model(y_true, y_pred, labels):
    """Build a per-class metrics table followed by one overall summary row.

    Args:
        y_true: Ground-truth class of each sample.
        y_pred: Predicted class of each sample.
        labels: Values for the ``Label`` column, one per class row. They are
            NOT passed to ``precision_recall_fscore_support``, so they are
            matched to the metric rows purely by position.
            NOTE(review): presumably this is scikit-learn's function, which
            reports the observed classes in sorted order -- confirm that
            callers supply ``labels`` in that same order.

    Returns:
        A DataFrame with columns Label, Precision, Recall, F1 and Support:
        one row per class, plus a final row labelled '总体' at index 99999
        holding the support-weighted averages and the total support.
    """
    column_order = ['Label', 'Precision', 'Recall', 'F1', 'Support']

    # Per-class precision, recall, F1 and support.
    precision, recall, fscore, support = precision_recall_fscore_support(y_true, y_pred)

    # Overall figures: each metric averaged across classes, weighted by the
    # class support; the overall support is the plain sum.
    overall_precision = np.average(precision, weights=support)
    overall_recall = np.average(recall, weights=support)
    overall_fscore = np.average(fscore, weights=support)
    overall_support = np.sum(support)

    per_class = pd.DataFrame({
        'Label': labels,
        'Precision': precision,
        'Recall': recall,
        'F1': fscore,
        'Support': support,
    })
    overall = pd.DataFrame(
        {
            'Label': ['总体'],
            'Precision': [overall_precision],
            'Recall': [overall_recall],
            'F1': [overall_fscore],
            'Support': [overall_support],
        },
        index=[99999],  # fixed index for the summary row
    )
    return pd.concat([per_class, overall])[column_order]

# Per-class evaluation of the trained model on the test file.
label_dict_file = 'data/8qi/label_dict.xls'
cate_dic = get_label_dict(label_dict_file)
# Map each fastText label ("__label__<key>") to the category value of <key>.
dict_cate = {'__label__{}'.format(k): v for k, v in cate_dic.items()}
y_true = []
y_pred = []
with open('data/8qi/test.txt', 'r', encoding='utf-8') as f:
    for line in f:
        line = line.strip()
        if not line:
            # A blank line carries no label; looking it up would raise KeyError.
            continue
        splits = line.split(" ")
        label = splits[0]
        # Everything after the label is fed back to the model as one string.
        words = [" ".join(splits[1:])]
        label = dict_cate[label]
        y_true.append(label)
        # predict() on a one-element list returns ([[label, ...]], [[prob, ...]]);
        # take the top label of the only sample.
        y_pred_results = model.predict(words)[0][0][0]
        y_pred.append(dict_cate[y_pred_results])
print("y_true = ",y_true[:5])
print("y_pred = ",y_pred[:5])
print('y_true length = ',len(y_true))
print('y_pred length = ',len(y_pred))

print('keys = ',list(cate_dic.keys()))

# eval_model() matches its `labels` argument to the metric rows by position,
# and the metric rows cover the classes actually present in y_true / y_pred
# in sorted order (scikit-learn's default). Passing cate_dic's keys in dict
# order can therefore mislabel rows, or fail when a class never occurs in the
# test set. Build the display names in exactly the order of the metric rows.
# NOTE(review): the reverse mapping assumes the values of cate_dic are unique.
value_to_key = {v: k for k, v in cate_dic.items()}
observed_classes = sorted(set(y_true) | set(y_pred))
report_labels = [value_to_key[c] for c in observed_classes]

report = eval_model(y_true, y_pred, report_labels)
# Print the table so it is visible when this runs as a plain script.
print(report)

cate_dic
{'问候语': 'GREETINGS', …}
keys = ['问候语', …]

       Label  Precision    Recall        F1  Support
0        问候语   0.941176  0.941176  0.941176       34
99999     总体   0.928100  0.922147  0.920730     1323

import jieba

text = "这个房子安静吗"
# jieba.lcut() already returns a list, so no extra copy is needed.
words = jieba.lcut(text)
print('words = ', words)
# Every training/test line has the form "__label__x , w1 w2 ...", and the
# evaluation loop feeds ", w1 w2 ..." to the model. Prefix the same "," token
# here so the input has the shape the model was trained and evaluated on.
# NOTE(review): the training lines were built after stop-word removal; this
# snippet does not filter stop words -- confirm whether it should.
data = " ".join([","] + words)

# predict
# For a one-element list predict() returns ([[label, ...]], [[prob, ...]]);
# take the top label of the only sample.
results = model.predict([data])
y_pred = results[0][0][0]
print("y_pred results = ", str(y_pred).replace('__label__', ''), dict_cate[y_pred])

猜你喜欢

转载自blog.csdn.net/qq236237606/article/details/111572554