# -*- coding: UTF-8 -*-
"""Multinomial Naive Bayes text classifier over a folder-per-class corpus.

Each sub-directory of the corpus folder is one class and every file in it is
one document.  Documents are segmented with jieba, a vocabulary is built from
the training split, the ``deleteN`` most frequent words are discarded, and a
MultinomialNB classifier is trained on binary word-presence features.  Run as
a script, it plots test accuracy against ``deleteN``.
"""
import os
import random
from collections import Counter

import jieba
from matplotlib import pyplot as plt
from sklearn.naive_bayes import MultinomialNB


def TextProcessing(folder_path, test_size=0.2, max_files_per_class=100):
    """Segment every document, split train/test, and build the vocabulary.

    Every sub-directory of ``folder_path`` is treated as one class (its name is
    the label) and each file inside it as one UTF-8 document.  At most
    ``max_files_per_class`` files are read per class.  Documents are shuffled
    and split; the vocabulary is built from the training split only and is
    sorted by descending frequency (ties keep first-seen order).

    :param folder_path: root directory containing one sub-directory per class
    :param test_size: fraction of documents used for the test split.  Note
        that ``int(n * test_size) + 1`` documents go to the test split, so it
        is never empty when at least one document exists.
    :param max_files_per_class: maximum number of files read from each class
        directory (default 100, the previously hard-coded limit)
    :return: (all_words_list, train_data_list, test_data_list,
        train_class_list, test_class_list) - the training vocabulary as a
        list, the token lists of the train/test documents (tuples) and their
        labels (tuples)
    """
    data_list = []
    class_list = []
    for folder in os.listdir(folder_path):
        new_folder_path = os.path.join(folder_path, folder)
        # Skip stray files (e.g. .DS_Store) next to the class directories;
        # os.listdir() on them would raise NotADirectoryError.
        if not os.path.isdir(new_folder_path):
            continue
        for file_name in os.listdir(new_folder_path)[:max_files_per_class]:
            with open(os.path.join(new_folder_path, file_name), 'r', encoding='utf-8') as f:
                raw = f.read()
            # Precise (non-"full") segmentation mode.
            data_list.append(list(jieba.cut(raw, cut_all=False)))
            class_list.append(folder)

    data_class_list = list(zip(data_list, class_list))
    random.shuffle(data_class_list)
    # The "+ 1" guarantees a non-empty test split (kept from the original).
    index = int(len(data_class_list) * test_size) + 1
    train_list = data_class_list[index:]
    test_list = data_class_list[:index]
    # Unpack without zip(*...) so an empty split yields empty tuples instead
    # of raising "not enough values to unpack".
    train_data_list = tuple(words for words, _ in train_list)
    train_class_list = tuple(label for _, label in train_list)
    test_data_list = tuple(words for words, _ in test_list)
    test_class_list = tuple(label for _, label in test_list)

    # Counter keeps first-seen insertion order and most_common() sorts stably
    # by count, descending - the same order as sorting a plain count dict.
    word_counts = Counter()
    for word_list in train_data_list:
        word_counts.update(word_list)
    all_words_list = [word for word, _ in word_counts.most_common()]
    return all_words_list, train_data_list, test_data_list, train_class_list, test_class_list


def MakeWordsSet(words_file):
    """Read a UTF-8 word list (one word per line) into a set.

    Surrounding whitespace is stripped from each line and blank lines are
    ignored.

    :param words_file: path of the word-list file (e.g. a stopword list)
    :return: set of the words in the file
    """
    with open(words_file, 'r', encoding='utf-8') as f:
        return {word for word in (line.strip() for line in f) if word}


def words_dict(all_words_list, deleteN, stopwords_set=None, max_features=1000):
    """Select feature words from a frequency-sorted vocabulary.

    Skips the first ``deleteN`` (most frequent) words.  It then keeps, in
    order, every word that is not purely digits, is not in ``stopwords_set``
    and has 2 to 4 characters.  It stops once ``max_features`` words have
    been collected.

    :param all_words_list: training vocabulary sorted by descending frequency
    :param deleteN: number of leading (most frequent) words to discard; must
        be non-negative
    :param stopwords_set: words that may never become features (default: none)
    :param max_features: maximum number of feature words (default 1000, the
        previously hard-coded limit)
    :return: list of feature words
    :raises ValueError: if ``deleteN`` is negative
    """
    if deleteN < 0:
        raise ValueError('deleteN must be non-negative, got {}'.format(deleteN))
    # None sentinel instead of a shared mutable default set.
    if stopwords_set is None:
        stopwords_set = frozenset()
    feature_words = []
    for word in all_words_list[deleteN:]:
        if len(feature_words) >= max_features:
            break
        if not word.isdigit() and word not in stopwords_set and 1 < len(word) < 5:
            feature_words.append(word)
    return feature_words


def TextFeatures(train_data_list, test_data_list, feature_words):
    """Turn tokenised documents into binary bag-of-words vectors.

    Position i of each vector is 1 if ``feature_words[i]`` occurs in the
    document and 0 otherwise.

    :param train_data_list: token lists of the training documents
    :param test_data_list: token lists of the test documents
    :param feature_words: ordered list of feature words
    :return: (train_feature_list, test_feature_list) - lists of 0/1 vectors
    """
    def text_features(text, feature_words):
        # Set lookup keeps each vector O(len(text) + len(feature_words)).
        text_words = set(text)
        return [1 if word in text_words else 0 for word in feature_words]

    train_feature_list = [text_features(text, feature_words) for text in train_data_list]
    test_feature_list = [text_features(text, feature_words) for text in test_data_list]
    return train_feature_list, test_feature_list


def TextClassifier(train_feature_list, test_feature_list, train_class_list, test_class_list):
    """Fit a MultinomialNB on the training split and score it on the test split.

    :param train_feature_list: training feature vectors
    :param test_feature_list: test feature vectors
    :param train_class_list: training labels
    :param test_class_list: test labels
    :return: mean accuracy on the test split, as returned by
        ``MultinomialNB.score``
    """
    classifier = MultinomialNB().fit(train_feature_list, train_class_list)
    return classifier.score(test_feature_list, test_class_list)


if __name__ == '__main__':
    # Segment the corpus and split it into training and test sets.
    folder_path = './SogouC/Sample'
    (all_words_list, train_data_list, test_data_list,
     train_class_list, test_class_list) = TextProcessing(folder_path, test_size=0.2)

    # Stopwords can never be selected as feature words.
    stopwords_file = './stopwords_cn.txt'
    stopwords_set = MakeWordsSet(stopwords_file)

    # Test accuracy when discarding the 0, 20, ..., 980 most frequent words.
    test_accuracy_list = []
    deleteNs = range(0, 1000, 20)
    for deleteN in deleteNs:
        feature_words = words_dict(all_words_list, deleteN, stopwords_set)
        train_feature_list, test_feature_list = TextFeatures(train_data_list, test_data_list, feature_words)
        test_accuracy = TextClassifier(train_feature_list, test_feature_list, train_class_list, test_class_list)
        test_accuracy_list.append(test_accuracy)

    fig = plt.figure()
    f1 = fig.add_subplot(111)
    f1.plot(deleteNs, test_accuracy_list)
    f1.set_title('Relationship of deleteNs and test_accuracy')
    f1.set_xlabel('deleteNs')
    f1.set_ylabel('test_accuracy')
    plt.show()
# 测试结果:运行上述脚本会绘制 deleteN 与测试集准确率(test_accuracy)的关系曲线。
# 参考:https://blog.csdn.net/c406495762/article/details/77500679