# coding=utf-8
"""Train a LinearSVC text classifier on articles stored in MySQL and print a
predicted category (``classid``) for a held-out set of articles.

Training rows are articles with ``id > 100``; prediction rows are articles
with ``id < 50`` (articles with 50 <= id <= 100 are used by neither query).
Article bodies are stripped of HTML tags, segmented with jieba, and turned
into TF-IDF weighted bag-of-words features.
"""
import re
import pandas as pd
import string
import MySQLdb
import jieba
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.feature_extraction.text import CountVectorizer
try:
    # sklearn >= 0.18; sklearn.cross_validation was removed in 0.20.
    from sklearn.model_selection import train_test_split
except ImportError:
    from sklearn.cross_validation import train_test_split
from sklearn.metrics import confusion_matrix
from sklearn import metrics
from sklearn.metrics import roc_curve, auc
from sklearn.svm import LinearSVC

# Compiled once instead of on every filter_html() call.
_HTML_TAG_RE = re.compile(r'<[^>]+>', re.S)

# Training query; the two %d slots are the LIMIT offset and page size.
_TRAIN_SQL = ("SELECT a.id,a.title,a.classid,b.artcontent "
              "FROM article a,article_txt b "
              "WHERE a.id=b.aid AND b.artcontent IS NOT NULL AND a.id>100 "
              "ORDER BY a.id ASC LIMIT %d,%d")

# Query for the articles whose category is predicted.
_TEST_SQL = ("SELECT a.id,a.title,a.classid,b.artcontent "
             "FROM article a,article_txt b "
             "WHERE a.id=b.aid AND b.artcontent IS NOT NULL AND a.id<50 "
             "ORDER BY a.id ASC")


def jieba_tokenizer(x):
    """Segment text *x* with jieba in full mode (``cut_all=True``).

    Returns the iterable produced by ``jieba.cut``.
    """
    return jieba.cut(x, cut_all=True)


def partition(x):
    """Map a raw ``classid`` to its training label (currently the identity)."""
    return x


def filter_html(s):
    """Return *s* with every ``<...>`` tag removed.

    Only tags are stripped; the text between tags (and any HTML entities)
    is kept as-is.
    """
    return _HTML_TAG_RE.sub('', s)


def _to_corpus(texts):
    """Strip HTML from each text and return it as space-joined jieba tokens."""
    # CountVectorizer tokenizes on word boundaries; without the spaces an
    # entire run of Chinese characters would be a single token.
    return [' '.join(jieba_tokenizer(filter_html(text))) for text in texts]


def _load_training_frame(conn, pages=5, page_size=1000):
    """Fetch up to ``pages * page_size`` training articles, one page at a time."""
    frames = [pd.read_sql_query(_TRAIN_SQL % (page * page_size, page_size), conn)
              for page in range(pages)]
    # A single concat replaces DataFrame.append (removed in pandas 2.0), which
    # re-copied the accumulated frame on every page.
    return pd.concat(frames, ignore_index=True)


def main():
    """Train on articles with id > 100, then print a predicted classid for
    every article with id < 50, one ``ID:<id> -> classid:<label>`` line each.

    Raises:
        ValueError: if the training query returns no rows.
    """
    # NOTE(review): credentials are hard-coded (root, empty password);
    # consider loading them from configuration.
    conn = MySQLdb.connect(host='localhost', user='root', passwd='',
                           db='article', port=3306, charset="utf8")
    try:
        cursor = conn.cursor()
        try:
            # Kept from the original; charset="utf8" on connect may already
            # cover this.
            cursor.execute("SET NAMES utf8")

            # --- training data ---
            train_data = _load_training_frame(conn)
            if train_data.empty:
                raise ValueError("no training articles found (a.id > 100)")
            corpus = _to_corpus(train_data['artcontent'])
            labels = train_data['classid'].map(partition)

            # NOTE(review): CountVectorizer's default token_pattern only keeps
            # tokens of two or more word characters, so single-character
            # Chinese words from jieba are dropped; confirm this is intended.
            count_vect = CountVectorizer()
            tfidf_transformer = TfidfTransformer()
            X_train_tfidf = tfidf_transformer.fit_transform(
                count_vect.fit_transform(corpus))
            # The fitted clf (and vectorizers) could be persisted so the model
            # need not be retrained on every run.
            clf = LinearSVC().fit(X_train_tfidf, labels)

            # --- test data: predict categories ---
            test_data = pd.read_sql_query(_TEST_SQL, conn)
            if test_data.empty:
                # Nothing to predict; sklearn would reject a 0-sample matrix.
                return
            X_test_tfidf = tfidf_transformer.transform(
                count_vect.transform(_to_corpus(test_data['artcontent'])))
            predictions = clf.predict(X_test_tfidf)
            for article_id, classid in zip(test_data['id'], predictions):
                print("ID:" + str(article_id) + " -> classid:" + str(classid))
        finally:
            cursor.close()
    finally:
        conn.close()


if __name__ == "__main__":
    main()
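# Article: "sklearn学习--读取mysql数据源进行训练样本和预测文本分类"
# (sklearn notes: reading training samples from a MySQL data source and
# predicting text categories)
# Reposted from: strayly.iteye.com/blog/2317526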