sklearn学习--文本分类：多标签分类应用

#!/usr/bin/env python
# coding=utf-8
import sys
import jieba
from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.svm import LinearSVC
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import MultiLabelBinarizer
import MySQLdb
import pandas as pd
import re 
import numpy as np

def jieba_tokenizer(x):
    """Split the text *x* into word tokens using jieba.

    ``cut_all=True`` asks jieba for its "full mode" segmentation.  The
    result of ``jieba.cut`` is returned unchanged, so the caller receives
    an iterable of tokens rather than a list.
    """
    tokens = jieba.cut(x, cut_all=True)
    return tokens
def partition(x):
    """Identity helper: hand back *x* exactly as it was received."""
    return x
def filter_html(s):
    """Return *s* with every ``<...>`` tag-like span removed.

    A span starts at ``<``, contains at least one character that is not
    ``>`` (newlines included), and ends at the next ``>``.  Text outside
    such spans is kept as is; no entity decoding is performed.
    """
    return re.sub(r'<[^>]+>', '', s, flags=re.S)
def gbk_utf8(s):
    """Re-encode the GBK byte string *s* as UTF-8 bytes.

    Bytes that cannot be decoded as GBK are silently dropped (the
    ``"ignore"`` error handler) instead of raising.
    """
    text = s.decode('gbk', "ignore")
    return text.encode('utf8')
# Connect to the MySQL database.
# use_unicode=False is passed so that text columns are not converted to
# unicode by the driver; the row-handling code below re-encodes them from
# GBK to UTF-8 itself via gbk_utf8().
conn = MySQLdb.connect(
    host='localhost',
    port=3306,
    user='root',
    passwd='',
    db='mydb',
    charset="gbk",
    use_unicode=False,
)
cursor = conn.cursor()
# Explicitly set the connection character set to GBK as well.
cursor.execute("SET NAMES GBK")


# Training samples: load article bodies and their class labels from MySQL.
data_ret = pd.DataFrame()

# The table t_reprint is aliased as `article` here, so the ORDER BY clause
# must not reference any other alias.  The previous `ORDER BY a.ID` named an
# alias `a` that does not exist in this statement, which makes MySQL reject
# the query; the bare column name is unambiguous with a single table.
sql = "SELECT ID, title,classid, content FROM t_reprint article WHERE ID<1000 ORDER BY ID ASC LIMIT 0,1000"
cursor.execute(sql)

txt_ret = []    # cleaned article bodies (UTF-8, tags stripped)
class_ret = []  # one list of class-id strings per article
id_ret = []     # article IDs, in the same order as txt_ret
for row in cursor.fetchall():
    # Row layout follows the SELECT list: ID, title, classid, content.
    content = filter_html(gbk_utf8(row[3]))
    txt_ret.append(content)
    # classid holds comma-separated class ids; each article may carry
    # several labels, hence a list per sample.
    class_s = gbk_utf8(row[2])
    class_l = class_s.split(",")
    class_ret.append(class_l)
    id_ret.append(row[0])

X_train = txt_ret
Y_train = class_ret


# Turn the per-article label lists into a binary indicator matrix
# (one column per distinct label).  The fitted binarizer is kept in `mlb`
# so predictions can be mapped back to label names later.
mlb = MultiLabelBinarizer()
Y_train = mlb.fit_transform(Y_train)

# Text pipeline: jieba-tokenized term counts -> TF-IDF weighting ->
# one LinearSVC per label via one-vs-rest.
pipeline_steps = [
    ('counter', CountVectorizer(tokenizer=jieba_tokenizer)),
    ('tfidf', TfidfTransformer()),
    ('clf', OneVsRestClassifier(LinearSVC())),
]
classifier = Pipeline(pipeline_steps)

classifier.fit(X_train, Y_train)



#target_names=['100','102','103','104','105','106','107','108','109','110','111','112','113','114','115','116','117','118','119','120','121','122','123','124','125','126','127','128','129','130','131', '132','133','134']
# Test data: the ten highest-ID articles with ID above 1000.
sql = "SELECT ID, title,classid, content FROM article  WHERE ID>1000 ORDER BY ID DESC LIMIT 10 "
cursor.execute(sql)
rows = cursor.fetchall()

# Row layout follows the SELECT list: ID, title, classid, content.
# Bodies get the same GBK -> UTF-8 conversion and tag stripping as the
# training samples; the IDs are kept in matching order for reporting.
test_txt_set = [filter_html(gbk_utf8(row[3])) for row in rows]
test_id_ret = [row[0] for row in rows]
X_test = test_txt_set


prediction = classifier.predict(X_test)

# Map the binary indicator rows back to tuples of label names.
result = mlb.inverse_transform(prediction)

# Show the results: one line per test article, labels comma-terminated.
for article_id, labels in zip(test_id_ret, result):
    classstr = ''.join(str(label) + "," for label in labels)
    # Single parenthesized argument: prints the same text as the bare
    # print statement would.
    print("ID:" + str(article_id) + " =>class:" + classstr)

猜你喜欢

转载自strayly.iteye.com/blog/2321066