sklearn学习--读取mysql数据源进行训练样本和预测文本分类

# coding=utf-8
import re
import pandas as pd
import string
import MySQLdb
import jieba

from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.cross_validation import train_test_split
from sklearn.metrics import confusion_matrix
from sklearn import metrics
from sklearn.metrics import roc_curve, auc
from sklearn.svm import LinearSVC

# jieba word segmentation
def jieba_tokenizer(x):
    """Segment ``x`` with jieba in full mode (``cut_all=True``).

    Returns whatever ``jieba.cut`` returns for the given text; callers in
    this file join the produced tokens with spaces.
    """
    return jieba.cut(x, cut_all=True)
def partition(x):
    """Identity mapping: hand back ``x`` unchanged."""
    return x
def filter_html(s):
    """Return ``s`` with every ``<...>`` tag-like span removed.

    A span is a ``<``, one or more characters that are not ``>``, and a
    closing ``>``; each such span is replaced by the empty string.
    """
    return re.sub(r'<[^>]+>', '', s, flags=re.S)

# Connect to the MySQL database.
conn = MySQLdb.connect(host='localhost', user='root', passwd='', db='article', port=3306, charset="utf8")
cursor = conn.cursor()
try:
    cursor.execute("SET NAMES utf8")

    # Training samples: articles with id > 100, fetched in 5 pages of up to
    # 1000 rows each. The pages are collected first and concatenated once,
    # instead of growing a DataFrame with repeated append() calls (which
    # re-copies the accumulated data on every iteration).
    frames = []
    for i in range(0, 5):
        sql = "SELECT a.id,a.title,a.classid,b.artcontent FROM article a,article_txt b WHERE a.id=b.aid AND b.artcontent IS NOT NULL AND a.id>100 ORDER BY a.id ASC LIMIT "+str(i*1000)+",1000"
        frames.append(pd.read_sql_query(sql, conn))
    data_ret = pd.concat(frames)

    # classid is the label column; artcontent (with HTML tags stripped) is
    # the text the classifier is trained on.
    Score = data_ret['classid']
    data_ret['artcontent'] = [filter_html(msg) for msg in data_ret['artcontent']]

    X_train = data_ret['artcontent']
    Y_train = Score.map(partition)

    # CountVectorizer splits on whitespace/punctuation, so each document is
    # pre-segmented with jieba and re-joined with spaces.
    corpus = [' '.join(jieba_tokenizer(txt)) for txt in X_train]
    count_vect = CountVectorizer()
    X_train_counts = count_vect.fit_transform(corpus)
    tfidf_transformer = TfidfTransformer()
    X_train_tfidf = tfidf_transformer.fit_transform(X_train_counts)
    clf = LinearSVC().fit(X_train_tfidf, Y_train)

    # NOTE: clf could be persisted at this point.

    # Test data: articles with id < 50, preprocessed the same way as the
    # training text and then classified.
    test_txt_data = pd.read_sql_query("SELECT a.id,a.title,a.classid,b.artcontent FROM article a,article_txt b WHERE a.id=b.aid AND b.artcontent IS NOT NULL AND a.id<50 ORDER BY a.id ASC", conn)
    X_test = [filter_html(msg) for msg in test_txt_data['artcontent']]
    test_set = [' '.join(jieba_tokenizer(text)) for text in X_test]

    # Reuse the vocabulary and IDF weights fitted on the training corpus.
    X_new_counts = count_vect.transform(test_set)
    X_test_tfidf = tfidf_transformer.transform(X_new_counts)

    result = clf.predict(X_test_tfidf)

    # Single parenthesized expression: prints the same line under both the
    # Python 2 print statement and the Python 3 print function.
    for article_id, predicted in zip(test_txt_data['id'], result):
        print("ID:" + str(article_id) + " -> classid:" + str(predicted))
finally:
    # Release the DB resources even when a query, tokenization or fitting
    # step raises.
    cursor.close()
    conn.close()

猜你喜欢

转载自strayly.iteye.com/blog/2317526