Python Development: Saving and Using sklearn Models, CountVectorizer, and TfidfTransformer

1. Overview

When working with TF-IDF, sklearn's CountVectorizer and TfidfTransformer classes are commonly used together. We always need to persist the TF-IDF vocabulary learned from the training set, so that TF-IDF features can later be computed for the test set. Note that sklearn objects can be persisted in two ways: pickle and joblib. Here, we use pickle to save the features and joblib to save the model.
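Both mechanisms serialize Python objects to disk; joblib is generally more efficient for objects holding large numpy arrays, while pickle handles any plain Python object such as a vocabulary dict. A minimal sketch of the two options (the model and file names here are purely illustrative):

import pickle
import joblib
from sklearn.linear_model import LogisticRegression

model = LogisticRegression().fit([[0], [1]], [0, 1])  # toy fitted estimator

# pickle: works for any Python object, e.g. a plain dict of features
with open("model.pkl", "wb") as fw:
    pickle.dump(model, fw)

# joblib: better suited to estimators carrying large numpy arrays
joblib.dump(model, "model.joblib")
model_again = joblib.load("model.joblib")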

2. Saving and Loading CountVectorizer and TfidfTransformer

2.1 Saving the TF-IDF vocabulary

import pickle
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer

train_content = segmentWord(X_train)  # segmentWord: the author's segmentation helper
test_content = segmentWord(X_test)
# decode_error="replace" is required; this vectorizer holds the training-set features
vectorizer = CountVectorizer(decode_error="replace")
tfidftransformer = TfidfTransformer()
# During training, you must call vectorizer.fit_transform and tfidftransformer.fit_transform;
# at prediction time, you must call vectorizer.transform and tfidftransformer.transform.
vec_train = vectorizer.fit_transform(train_content)
tfidf = tfidftransformer.fit_transform(vec_train)

# Persist the fitted vectorizer and the fitted tfidftransformer for prediction time
feature_path = 'models/feature.pkl'
with open(feature_path, 'wb') as fw:
    pickle.dump(vectorizer.vocabulary_, fw)

tfidftransformer_path = 'models/tfidftransformer.pkl'
with open(tfidftransformer_path, 'wb') as fw:
    pickle.dump(tfidftransformer, fw)

Note: both the vectorizer and the tfidftransformer must be saved, and only after fit_transform has been called, i.e. only once both have been fitted on the training set.
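A simple way to make that precondition explicit is to check for the fitted attributes before dumping; vocabulary_ only exists on a fitted CountVectorizer, and idf_ only on a fitted TfidfTransformer:

# Both attributes are created by fit/fit_transform; an unfitted object lacks them
assert hasattr(vectorizer, "vocabulary_"), "vectorizer must be fitted before saving"
assert hasattr(tfidftransformer, "idf_"), "tfidftransformer must be fitted before saving"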

2.2 Loading TF-IDF and transforming new data

import pickle
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer

# Load the features (the fitted vocabulary)
feature_path = 'models/feature.pkl'
loaded_vec = CountVectorizer(decode_error="replace", vocabulary=pickle.load(open(feature_path, "rb")))
# Load the fitted TfidfTransformer
tfidftransformer_path = 'models/tfidftransformer.pkl'
tfidftransformer = pickle.load(open(tfidftransformer_path, "rb"))
# Use transform (not fit_transform); test_content is the list of segmented test texts
test_tfidf = tfidftransformer.transform(loaded_vec.transform(test_content))
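Because loaded_vec was rebuilt from the training vocabulary, test_tfidf has exactly as many columns as the training matrix, so it can be passed straight to a classifier trained on that matrix. A sketch, where clf stands for whatever model was trained on the training TF-IDF features:

# clf is assumed to be a classifier fitted on the training TF-IDF matrix
predictions = clf.predict(test_tfidf)
print(predictions[:10])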

3. Saving and Loading the Model

3.1 Saving the model

import joblib

# clf_model is the trained model; persist it with joblib.dump
clf_model = trainModel()  # trainModel: the author's training routine
joblib.dump(clf_model, "model_" + path)

3.2 Loading the model

# Load the model back with joblib.load
clf_model = joblib.load(model_path)
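A self-contained round trip, using a toy dataset purely for illustration:

import joblib
from sklearn import svm

X, Y = [[0, 1], [1, 0]], ["neg", "pos"]  # toy training data
clf_model = svm.LinearSVC().fit(X, Y)

joblib.dump(clf_model, "model_demo.pkl")   # save
clf_model = joblib.load("model_demo.pkl")  # load
print(clf_model.predict([[0, 1]]))         # expected: ['neg']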

4. Example

Here is a simple end-to-end example showing how everything above fits together.

"""
Author:沙振宇
Time:20191112
Info:简单的情绪识别
"""
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.feature_extraction.text import CountVectorizer
from sklearn import svm
from sklearn.metrics import accuracy_score
import joblib
import os
import jieba
import datetime
import warnings  # suppress warnings
import pickle
warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn", lineno=196)

m_list_allText = []  # all texts (with duplicates, in order)
m_list_allL4ID = []  # all labels (with duplicates, in order)
m_list_allLabel = []  # all model labels (deduplicated, in order)
m_map_id_score = {}  # score for each label id

# Read the data file and extract labels and texts. Lines are ":"-separated, with
# the label in the second-to-last field and the text in the last field.
# count is an optional line limit (unused here).
def getFile(filename, count = -1):
    with open(filename, 'r', encoding='utf-8') as fp:
        global m_list_allL4ID, m_list_allText
        m_list_allL4ID = []
        m_list_allText = []
        m_file_text = fp.readlines()
        for text in m_file_text:
            if ":" in text:
                L4ID = text.split(":")[-2]
                Msg = text.split(":")[-1]
                m_list_allL4ID.append(L4ID)
                m_list_allText.append(Msg)
                if L4ID not in m_list_allLabel:
                    m_list_allLabel.append(L4ID)

# Module-level TF-IDF transformer: fitted once on the training data inside
# jiabaToVector, then reused (transform only) at prediction time, matching
# the fit_transform/transform rule from section 2.
m_transformer = TfidfTransformer()

# Tokenize with jieba, then build count vectors or TF-IDF vectors
def jiabaToVector(sentences, isTest, isTFIDF = False):
    tmp_list = []
    for sentence in sentences:
        tmp_list.append(" ".join(jieba.cut(sentence.strip())))
    if isTest:
        if isTFIDF:
            # Prediction: transform only, reusing the vocabulary and IDF from training
            tfidf = m_transformer.transform(vectorizer.transform(tmp_list))
        else:
            tfidf = vectorizer.transform(tmp_list)
    else:
        if isTFIDF:
            # Training: fit_transform learns the vocabulary and the IDF weights
            tfidf = m_transformer.fit_transform(vectorizer.fit_transform(tmp_list))
        else:
            tfidf = vectorizer.fit_transform(tmp_list)
    return tfidf

# Build an SVM classifier with default parameters
def predict_4(X, Y):
    clf = svm.LinearSVC()
    clf = clf.fit(X, Y)
    return clf

# Convert L4 labels to integers (drop the leading character)
def L4ToInt(m_label_l4):
    m_label_l4New = []
    for i in range(len(m_label_l4)):
        m_label_l4New.append(int(m_label_l4[i][1:]))
    # print("m_label_l4New:", m_label_l4New)
    return m_label_l4New

# Train the SVM model
def trainSVM(path, linecount = -1):
    # vectorizer is module-level so that jiabaToVector (and later prediction) can see it
    global vectorizer
    getFile(path, linecount)

    vectorizer = CountVectorizer(decode_error="replace")  # decode_error="replace" is required
    vector_train = jiabaToVector(m_list_allText, False, True)  # build the training vectors

    lenall = len(m_list_allText)
    print("Dataset size:", lenall)

    startT_Train = datetime.datetime.now()
    clf = predict_4(vector_train, m_list_allL4ID)
    endT_Train = datetime.datetime.now()
    print("Training time:", (endT_Train - startT_Train).microseconds)
    return clf, vectorizer

# Check for an exact match against the labelled data
def completeLabelDataMatch(path, query):
    outList = {}
    file_train = os.path.join(path)
    with open(file_train, 'r', encoding='UTF-8') as fp:
        textlist = fp.readlines()
        for text in textlist:
            if ":" in text:
                conditionId = text.split(":")[-2]
                Msg = text.split(":")[-1]
                message = Msg.strip("\n")
                if query == message:
                    outList["conditionId"] = conditionId
                    outList["Score"] = 1
                    print("Complete labelData match work: %s:%s" % (conditionId, message))
                    return outList
    return False

# Score the query: try an exact match first, then fall back to the SVM model
def SVMMain(path, clf, query, score):
    outList = completeLabelDataMatch(path, query)
    if outList:
        print("outList[\"conditionId\"]:", outList["conditionId"])
        print("outList[\"Score\"]:", outList["Score"])
    else:
        outList = useSVM(clf, query)

    if outList["Score"] > score:
        return emotionAnalysis(outList["conditionId"])
    else:
        return "normal"

# Apply the SVM model to a single query
def useSVM(clf, query):
    outList = {}
    querylist = [query]
    vector_test = jiabaToVector(querylist, True, True)  # build the test vector
    startT = datetime.datetime.now()

    # decision_function returns one score per class; its columns are ordered like
    # clf.classes_, so the labels remain available even for a model loaded from disk
    scorelist = clf.decision_function(vector_test)[0]
    for i in range(len(scorelist)):
        m_map_id_score[clf.classes_[i]] = scorelist[i]

    best = scorelist.argmax()
    conditionID = clf.classes_[best]
    percent = scorelist[best]

    endT = datetime.datetime.now()
    print("Prediction time:", (endT - startT).microseconds)
    outList["conditionId"] = conditionID
    outList["Score"] = percent
    print("outList[\"conditionId\"]:", outList["conditionId"])
    print("outList[\"Score\"]:", outList["Score"])
    return outList

# Map a label id to an emotion
def emotionAnalysis(label):
    negativeId = ['4000447','4000448','4000453','4000449','4000450','4000451','4000452','4000454','4000459','4000458','4002227','4000461','4000460','4000465','4000464','4000803','4000468']  # negative
    positiveId = ['4000439','4000440','4000441','4000442','4000462','4000467','4000469','4000496','4000497']  # positive
    if label in negativeId:
        return "negative"
    elif label in positiveId:
        return "positive"
    else:
        return "normal"

# Save the model and the features (vocabulary + fitted TF-IDF transformer)
def saveModel(path):
    clf, vectorizer = trainSVM(path, -1)
    name = path.split("/")[-1]
    base = name.split(".txt")[0] if ".txt" in name else name
    joblib.dump(clf, "model_" + name)
    print("Model saved, now saving the features")

    feature_path = 'feature_' + base + '.pkl'
    with open(feature_path, 'wb') as fw:
        pickle.dump(vectorizer.vocabulary_, fw)

    # Also persist the fitted TfidfTransformer, as in section 2.1; otherwise the
    # IDF weights learned during training would be lost when the script restarts
    tfidftransformer_path = 'tfidftransformer_' + base + '.pkl'
    with open(tfidftransformer_path, 'wb') as fw:
        pickle.dump(m_transformer, fw)
    print("Features saved.")

# Load the model and the features
def useModel(model_path, feature_path, transformer_path):
    global m_transformer
    # Load the model
    clf = joblib.load(model_path)
    # Load the features (the fitted vocabulary)
    loaded_vec = CountVectorizer(decode_error="replace", vocabulary=pickle.load(open(feature_path, "rb")))
    # Load the fitted TF-IDF transformer saved by saveModel
    m_transformer = pickle.load(open(transformer_path, "rb"))
    return clf, loaded_vec

if __name__ == "__main__":
    path = "../rg_train_20171230_1000008.txt"
    # Save the model (run once before using the loading branch below)
    # saveModel(path)

    # Load the model, the vocabulary, and the TF-IDF transformer
    clf, vectorizer = useModel("model_rg_train_20171230_1000008.txt",
                               "feature_rg_train_20171230_1000008.pkl",
                               "tfidftransformer_rg_train_20171230_1000008.pkl")

    source = "我很开心"  # "I am very happy"
    source = source.replace("\r", "")
    source = source.replace("\n", "")
    source = source.lower()
    source = source[0:256]
    print("Start matching")
    result = SVMMain(path, clf, source, 0.6)

    if result == "normal":
        print("Neutral emotion")
    elif result == "negative":
        print("Negative emotion")
    elif result == "positive":
        print("Positive emotion")
