自然语言处理(NLP):07 fastText训练中文模型-文本分类

fastText 另外两种安装方式

conda install 方式:速度慢

https://anaconda.org/conda-forge/fasttext

windows 版本下可以通过whl安装(fasttext‑0.9.1‑cp36‑cp36m‑win32.whl) ,windows 下可以使用这个安装

https://www.lfd.uci.edu/~gohlke/pythonlibs/#fasttext

fastText 训练

import fastText
import fastText
import pandas as pd
import numpy as np
from sklearn.metrics import confusion_matrix,precision_recall_fscore_support
# 训练
'''

dtrain.txt 和dtest.txt 数据格式 如下:

__label__2 中新网 日电 日前 上海 国际
__label__0 两人 被捕 警方 指控 非法
__label__3 中旬 航渡 过程 美军 第一
__label__1 强强 联手 背后 品牌 用户 双赢
'''
model = fastText.train_supervised(
    '../data/dtrain.txt',
    lr=0.1,
    dim=200,
    epoch=50,
    neg=5,
    wordNgrams=2,
    label="__label__"
    )
# 预测
result = model.test('../data/dtest.txt')
print('y_pred = ',y_pred)
# 保存model
model_path = '../model/fastText_model.pkl'
model.save_model(model_path)
# 计算分类的metrics
#绘制precision、recall、f1-score、support报告表
def eval_model(y_true, y_pred, labels):
    # 计算每个分类的Precision, Recall, f1, support
    p, r, f1, s = precision_recall_fscore_support(y_true, y_pred)
    # 计算总体的平均Precision, Recall, f1, support
    tot_p = np.average(p, weights=s)
    tot_r = np.average(r, weights=s)
    tot_f1 = np.average(f1, weights=s)
    tot_s = np.sum(s)
    res1 = pd.DataFrame({
        u'Label': labels,
        u'Precision': p,
        u'Recall': r,
        u'F1': f1,
        u'Support': s
    })
    res2 = pd.DataFrame({
        u'Label': ['总体'],
        u'Precision': [tot_p],
        u'Recall': [tot_r],
        u'F1': [tot_f1],
        u'Support': [tot_s]
    })
    res2.index = [999]
    res = pd.concat([res1, res2])
    return res[['Label', 'Precision', 'Recall', 'F1', 'Support']]

cate_dic = {'entertainment': 0, 'technology': 1, 'sports': 2, 'military': 3, 'car': 4}
dict_cate = dict(('__label__{}'.format(v),k) for k,v in cate_dic.items())
y_true= []
y_pred = []
with open('../data/dtest.txt','r',encoding='utf-8') as f:
    for line in f.readlines():
        line = line.strip()
        splits = line.split(" ")
        label = splits[0]
        words = [" ".join(splits[1:])]
        label = dict_cate[label]
        y_true.append(label)
        y_pred_results = clf.predict(words)[0][0][0]
        y_pred.append(dict_cate[y_pred_results])
print("y_true = ",y_true[:5])
print("y_pred = ",y_pred[:5])
print('y_true length = ',len(y_true))
print('y_pred length = ',len(y_pred))

print('keys = ',list(cate_dic.keys()))
y_true =  ['sports', 'car', 'car', 'technology', 'entertainment']
y_pred =  ['sports', 'car', 'car', 'technology', 'entertainment']
y_true length =  87581
y_pred length =  87581
keys =  ['entertainment', 'technology', 'sports', 'military', 'car']
eval_model(y_true,y_pred,list(cate_dic.keys()))
Label Precision Recall F1 Support
0 entertainment 0.934803 0.827857 0.878086 8400
1 technology 0.906027 0.923472 0.914666 26696
2 sports 0.881885 0.911727 0.896558 11555
3 military 0.943886 0.931749 0.937778 22476
4 car 0.857226 0.873252 0.865165 18454
999 总体 0.905035 0.904294 0.904270 87581

模拟在线预测

# 加载模型
model_path = '../model/fastText_model.pkl'
clf = fastText.load_model(model_path)
cate_dic = {'entertainment': 0, 'technology': 1, 'sports': 2, 'military': 3, 'car': 4}
print(cate_dic)

dict_cate = dict(('__label__{}'.format(v),k) for k,v in cate_dic.items()) 
print(dict_cate)
{'entertainment': 0, 'technology': 1, 'sports': 2, 'military': 3, 'car': 4}
{'__label__0': 'entertainment', '__label__1': 'technology', '__label__2': 'sports', '__label__3': 'military', '__label__4': 'car'}
  • 预测案例1-汽车类
    摘自今日头条: https://www.toutiao.com/a6714271125473346055/
import jieba
text = "奥迪A3、宝马1系和奔驰A级一直纠缠不休的三个冤家"
words = [word for word in jieba.lcut(text)]
print('words = ',words)
data = " ".join(words)

# predict
results = clf.predict([data])
y_pred = results[0][0][0]
print("y_pred results = ",dict_cate[y_pred])
words =  ['奥迪', 'A3', '、', '宝马', '1', '系', '和', '奔驰', 'A', '级', '一直', '纠缠', '不休', '的', '三个', '冤家']
y_pred results =  car
  • 预测案例2-军事类
    摘自今日头条新闻: https://www.toutiao.com/a6714188329937535496/
import jieba
text = "谁说文物只能躺在博物馆,想买一架梦想中的战斗机开着兜风吗?"
words = [word for word in jieba.lcut(text)]
print('words = ',words)
data = " ".join(words)
# predict
results = clf.predict([data])
y_pred = results[0][0][0]
print("y_pred results = ",dict_cate[y_pred])
words =  ['谁', '说', '文物', '只能', '躺', '在', '博物馆', ',', '想', '买', '一架', '梦想', '中', '的', '战斗机', '开着', '兜风', '吗', '?']
y_pred results =  military
  • 预测案例3-娱乐类
    我们从 今日头条: https://www.toutiao.com/a6689675139333751299/ 拷贝标题来进行预测
import jieba
text = "陈晓旭:从完美林黛玉到身家过亿后剃度出家,她戏里戏外都是传奇"
words = [word for word in jieba.lcut(text)]
print('words = ',words)
data = " ".join(words)
results = clf.predict([data])
y_pred = results[0][0][0]
print("y_pred results = ",dict_cate[y_pred])
words =  ['陈晓旭', ':', '从', '完美', '林黛玉', '到', '身家', '过', '亿后', '剃度', '出家', ',', '她', '戏里', '戏外', '都', '是', '传奇']
y_pred results =  entertainment
  • 预测案例4-体育类
    摘自今日头条:https://www.toutiao.com/a6714266792253981192/
import jieba
text = "男女有别!国乒主力参加马来西亚T2联赛 男队站着吃自助女队吃桌餐"
words = [word for word in jieba.lcut(text)]
print('words = ',words)
data = " ".join(words)
results = clf.predict([data])
y_pred = results[0][0][0]
print("y_pred results = ",dict_cate[y_pred])
words =  ['男女有别', '!', '国乒', '主力', '参加', '马来西亚', 'T2', '联赛', ' ', '男队', '站', '着', '吃', '自助', '女队', '吃', '桌餐']
y_pred results =  sports
  • 预测案例5-科技类
import jieba
text = "摩托罗拉One Macro将是最新一款Android One智能手机"
words = [word for word in jieba.lcut(text)]
print('words = ',words)
data = " ".join(words)
results = clf.predict([data])
y_pred = results[0][0][0]
print("y_pred results = ",dict_cate[y_pred])
words =  ['摩托罗拉', 'One', ' ', 'Macro', '将', '是', '最新', '一款', 'Android', ' ', 'One', '智能手机']
y_pred results =  technology

Flask Web 服务在线预测

http://127.0.0.1:5000/v1/p?q=xxxxx

其中: q 是要预测的样本

# -*- coding: UTF-8 -*-
import jieba
import fastText
from flask import Flask
from flask import request

app = Flask(__name__)
model_path = '../model/fastText_model.pkl'
clf = fastText.load_model(model_path)
cate_dic = {'entertainment': 0, 'technology': 1, 'sports': 2, 'military': 3, 'car': 4}
dict_cate = dict(('__label__{}'.format(v), k) for k, v in cate_dic.items())
print(dict_cate)


@app.route('/')
def hello_world():
    return 'Hello World!'


@app.route('/v1/p', methods=['POST', 'GET'])
def predict():
    if request.method == 'POST':
        q = request.form['q']
    else:
        q = request.args.get('q', '')
        print('q = ', q)

    print('input data:', q)
    words = [word for word in jieba.lcut(q)]
    print('words = ', words)
    data = " ".join(words)
    results = clf.predict([data])
    y_pred = results[0][0][0]
    return dict_cate[y_pred]

if __name__ == '__main__':
    app.run()
发布了267 篇原创文章 · 获赞 66 · 访问量 43万+

猜你喜欢

转载自blog.csdn.net/shenfuli/article/details/98882655