sklearn + Python: Naive Bayes and Text Classification

Copyright notice: this is an original post by the author; please credit the original when reposting. https://blog.csdn.net/yuanlulu/article/details/82558938

Naive Bayes

Bayes' theorem is used to compute conditional probabilities:

P(y \mid x_1, \dots, x_n) = \frac{P(y)\, P(x_1, \dots, x_n \mid y)}{P(x_1, \dots, x_n)}

We then make a naive assumption: every pair of features is conditionally independent given the class:

P(x_i \mid y, x_1, \dots, x_{i-1}, x_{i+1}, \dots, x_n) = P(x_i \mid y)

so the posterior simplifies to

P(y \mid x_1, \dots, x_n) = \frac{P(y) \prod_{i=1}^{n} P(x_i \mid y)}{P(x_1, \dots, x_n)}

Since P(x_1, \dots, x_n) is constant for a given input, we can use the following classification rule:

P(y \mid x_1, \dots, x_n) \propto P(y) \prod_{i=1}^{n} P(x_i \mid y)

\hat{y} = \arg\max_y P(y) \prod_{i=1}^{n} P(x_i \mid y)

Maximum A Posteriori (MAP) estimation can be used to estimate P(y) and P(x_i \mid y); the former is simply the relative frequency of class y in the training set.

The various Naive Bayes classifiers differ mainly in the assumptions they make about the distribution of P(x_i \mid y).

Given enough data, P(x_i \mid y) can also be estimated directly from the dataset by counting, or it can be modeled with any of several distributions.
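For count-valued features, for instance, these estimates reduce to smoothed relative frequencies (\alpha is the smoothing parameter, which reappears below in the code as MultinomialNB(alpha=0.0001)): \hat{P}(y) is the fraction of training samples belonging to class y, and

\hat{\theta}_{yi} = \hat{P}(x_i \mid y) = \frac{N_{yi} + \alpha}{N_y + \alpha n}

where N_{yi} is the total count of feature i across the class-y training samples, N_y = \sum_{i} N_{yi}, and n is the number of features.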

Probability Distributions

A function describing the probability distribution of a continuous random variable is called a probability density function. A typical probability density function is the Gaussian.
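For example, Gaussian Naive Bayes (GaussianNB in sklearn) assumes that, within each class y, feature x_i follows the Gaussian density

P(x_i \mid y) = \frac{1}{\sqrt{2\pi\sigma_y^2}} \exp\!\left(-\frac{(x_i - \mu_y)^2}{2\sigma_y^2}\right)

with \mu_y and \sigma_y^2 estimated from the class-y training samples by maximum likelihood.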

A function describing the probability distribution of a discrete random variable is called a probability mass function. A typical probability mass function is the multinomial, which is well suited to text classification and spam detection.
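Under the multinomial model, a document is represented by its word-count vector (x_1, \dots, x_n), and its class-conditional likelihood is, up to a normalizing factor that does not depend on y,

P(x_1, \dots, x_n \mid y) \propto \prod_{i=1}^{n} \theta_{yi}^{x_i}

where \theta_{yi} is the probability of word i occurring in a document of class y (the \hat{\theta}_{yi} estimated above).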


Together, these two kinds of functions are referred to as probability distribution functions.

The Bernoulli distribution is similar to the multinomial one, but it explicitly penalizes the non-occurrence of a feature i that is an indicator of class y, whereas the multinomial variant simply ignores features that do not occur. In text classification, BernoulliNB, which works on word occurrence vectors (rather than word count vectors), may perform better on some datasets, especially those with shorter documents. If time permits, it is advisable to evaluate both models, as sketched below.
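A minimal sketch of such a comparison, assuming news_train and news_test have already been loaded with load_files as in the full script below (the alpha value here is only illustrative):

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import BernoulliNB

# BernoulliNB models binary occurrence features, so vectorize with binary=True
binary_vectorizer = CountVectorizer(encoding='latin-1', binary=True)
Xb_train = binary_vectorizer.fit_transform(news_train.data)
Xb_test = binary_vectorizer.transform(news_test.data)

bnb = BernoulliNB(alpha=0.01)  # illustrative smoothing value
bnb.fit(Xb_train, news_train.target)
print("BernoulliNB test score:", bnb.score(Xb_test, news_test.target))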

Code: Text Classification with Multinomial Naive Bayes

The following code is mostly from Huang Yongchang's book; I have reorganized it and added some content. The dataset used by the code can be downloaded here: https://download.csdn.net/download/rendo/10287144

from time import time
from sklearn.datasets import load_files
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix

# Step 1: load the news corpus documents
print("loading train dataset ...")
t = time()
# load_files loads pre-classified documents: one subdirectory per category, with the directory name as the category name
news_train = load_files('D:/项目/机器学习和机器视觉/scikit-learn机器学习-源代码/code/datasets/mlcomp/379/train')
print("summary: {0} documents in {1} categories.".format(
    len(news_train.data), len(news_train.target_names)))
# news_train.target is a 1-D array of length 13180; each value is the category id of the corresponding document
print('news_categories_names:\n{}, \nlen(target):{}, target:{}'.format(news_train.target_names,
                                                                       len(news_train.target), news_train.target))
print("done in {0} seconds\n".format(round(time() - t, 2)))

# Step 2: convert the documents into TF-IDF vectors
print("vectorizing train dataset ...")
t = time()
vectorizer = TfidfVectorizer(encoding='latin-1')
X_train = vectorizer.fit_transform((d for d in news_train.data))
print("n_samples: %d, n_features: %d" % X_train.shape)
# Each row of X_train is one document; each entry is a word's TF-IDF value, i.e. how important that word is to the document.
# X_train has shape 13180 x 130274
print("number of non-zero features in sample [{0}]: {1}".format(
    news_train.filenames[0], X_train[0].getnnz()))
print("done in {0} seconds\n".format(round(time() - t, 2)))

# Step 3: train a multinomial Naive Bayes model
print("training models ...")
t = time()
y_train = news_train.target
# alpha is the additive (Laplace/Lidstone) smoothing parameter; a very small value means almost no smoothing
clf = MultinomialNB(alpha=0.0001)
clf.fit(X_train, y_train)
train_score = clf.score(X_train, y_train)
print("train score: {0}".format(train_score))
print("done in {0} seconds\n".format(round(time() - t, 2)))

# Step 4: load the test dataset
print("loading test dataset ...")
t = time()
news_test = load_files('D:/项目/机器学习和机器视觉/scikit-learn机器学习-源代码/code/datasets/mlcomp/379/test')
print("summary: {0} documents in {1} categories.".format(
    len(news_test.data), len(news_test.target_names)))
print("done in {0} seconds\n".format(round(time() - t, 2)))

# Step 5: vectorize the test dataset
print("vectorizing test dataset ...")
t = time()
# Note: transform is used here, not fit_transform as above, because the vocabulary and IDF statistics were already fitted on the training data
X_test = vectorizer.transform((d for d in news_test.data))
y_test = news_test.target
print("n_samples: %d, n_features: %d" % X_test.shape)
print("number of non-zero features in sample [{0}]: {1}".format(
    news_test.filenames[0], X_test[0].getnnz()))
print("done in %fs\n" % (time() - t))

# Step 6: test with the test dataset; predict the first document
print("predict for {} ...".format(news_test.filenames[0]))
pred = clf.predict(X_test[0])
print("predict: {0} is in category {1}".format(
    news_test.filenames[0], news_test.target_names[pred[0]]))
print("actually: {0} is in category {1}\n".format(
    news_test.filenames[0], news_test.target_names[news_test.target[0]]))

# Step 7: evaluate the prediction quality
print("predicting test dataset ...")
t = time()
pred = clf.predict(X_test)
print("done in %fs" % (time() - t))
print("classification report on test set for classifier:")
print(clf)
print(classification_report(y_test, pred,
                            target_names=news_test.target_names))

# Step 8: generate the confusion matrix
cm = confusion_matrix(y_test, pred)
print("confusion matrix:")
print(cm)

Output:

loading train dataset ...
summary: 13180 documents in 20 categories.
news_categories_names:
['alt.atheism', 'comp.graphics', 'comp.os.ms-windows.misc', 'comp.sys.ibm.pc.hardware', 'comp.sys.mac.hardware', 'comp.windows.x', 'misc.forsale', 'rec.autos', 'rec.motorcycles', 'rec.sport.baseball', 'rec.sport.hockey', 'sci.crypt', 'sci.electronics', 'sci.med', 'sci.space', 'soc.religion.christian', 'talk.politics.guns', 'talk.politics.mideast', 'talk.politics.misc', 'talk.religion.misc'], 
len(target):13180, target:[18 13  1 ... 14 15  4]
done in 1.97 seconds

vectorizing train dataset ...
n_samples: 13180, n_features: 130274
number of non-zero features in sample [D:/项目/机器学习和机器视觉/scikit-learn机器学习-源代码/code/datasets/mlcomp/379/train\talk.politics.misc\17860-178992]: 108
done in 5.35 seconds

training models ...
train score: 0.9978755690440061
done in 0.37 seconds

loading test dataset ...
summary: 5648 documents in 20 categories.
done in 0.86 seconds

vectorizing test dataset ...
n_samples: 5648, n_features: 130274
number of non-zero features in sample [D:/项目/机器学习和机器视觉/scikit-learn机器学习-源代码/code/datasets/mlcomp/379/test\rec.autos\7429-103268]: 61
done in 2.129905s

predict for D:/项目/机器学习和机器视觉/scikit-learn机器学习-源代码/code/datasets/mlcomp/379/test\rec.autos\7429-103268 ...
predict: D:/项目/机器学习和机器视觉/scikit-learn机器学习-源代码/code/datasets/mlcomp/379/test\rec.autos\7429-103268 is in category rec.autos
actually: D:/项目/机器学习和机器视觉/scikit-learn机器学习-源代码/code/datasets/mlcomp/379/test\rec.autos\7429-103268 is in category rec.autos

predicting test dataset ...
done in 0.041888s
classification report on test set for classifier:
MultinomialNB(alpha=0.0001, class_prior=None, fit_prior=True)
                          precision    recall  f1-score   support

             alt.atheism       0.90      0.91      0.91       245
           comp.graphics       0.80      0.90      0.85       298
 comp.os.ms-windows.misc       0.82      0.79      0.80       292
comp.sys.ibm.pc.hardware       0.81      0.80      0.81       301
   comp.sys.mac.hardware       0.90      0.91      0.91       256
          comp.windows.x       0.88      0.88      0.88       297
            misc.forsale       0.87      0.81      0.84       290
               rec.autos       0.92      0.93      0.92       324
         rec.motorcycles       0.96      0.96      0.96       294
      rec.sport.baseball       0.97      0.94      0.96       315
        rec.sport.hockey       0.96      0.99      0.98       302
               sci.crypt       0.95      0.96      0.95       297
         sci.electronics       0.91      0.85      0.88       313
                 sci.med       0.96      0.96      0.96       277
               sci.space       0.94      0.97      0.96       305
  soc.religion.christian       0.93      0.96      0.94       293
      talk.politics.guns       0.91      0.96      0.93       246
   talk.politics.mideast       0.96      0.98      0.97       296
      talk.politics.misc       0.90      0.90      0.90       236
      talk.religion.misc       0.89      0.78      0.83       171

             avg / total       0.91      0.91      0.91      5648

confusion matrix:
[[224   0   0   0   0   0   0   0   0   0   0   0   0   0   2   5   0   0   1  13]
 [  1 267   5   5   2   8   1   1   0   0   0   2   3   2   1   0   0   0   0   0]
 [  1  13 230  24   4  10   5   0   0   0   0   1   2   1   0   0   0   0   1   0]
 [  0   9  21 242   7   2  10   1   0   0   1   1   7   0   0   0   0   0   0   0]
 [  0   1   5   5 233   2   2   2   1   0   0   3   1   0   1   0   0   0   0   0]
 [  0  20   6   3   1 260   0   0   0   2   0   1   0   0   2   0   2   0   0   0]
 [  0   2   5  12   3   1 235  10   2   3   1   0   7   0   2   0   2   1   4   0]
 [  0   1   0   0   1   0   8 300   4   1   0   0   1   2   3   0   2   0   1   0]
 [  0   1   0   0   0   2   2   3 283   0   0   0   1   0   0   0   0   0   1   1]
 [  0   1   1   0   1   2   1   2   0 297   8   1   0   1   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   2   2 298   0   0   0   0   0   0   0   0   0]
 [  0   1   2   0   0   1   1   0   0   0   0 284   2   1   0   0   2   1   2   0]
 [  0  11   3   5   4   2   4   5   1   1   0   4 266   1   4   0   1   0   1   0]
 [  1   1   0   1   0   2   1   0   0   0   0   0   1 266   2   1   0   0   1   0]
 [  0   3   0   0   1   1   0   0   0   0   0   1   0   1 296   0   1   0   1   0]
 [  3   1   0   1   0   0   0   0   0   0   1   0   0   2   1 280   0   1   1   2]
 [  1   0   2   0   0   0   0   0   1   0   0   0   0   0   0   0 236   1   4   1]
 [  1   0   0   0   0   1   0   0   0   0   0   0   0   0   0   3   0 290   1   0]
 [  2   1   0   0   1   1   0   1   0   0   0   0   0   0   0   1  10   7  212  0]
 [ 16   0   0   0   0   0   0   0   0   0   0   0   0   0   0  12   4   1   4 134]]
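The alpha=0.0001 above was fixed by hand. As an optional check, the smoothing parameter could be tuned by cross-validation; a sketch reusing X_train and y_train from the script above:

from sklearn.model_selection import GridSearchCV
from sklearn.naive_bayes import MultinomialNB

# search a small grid of smoothing values with 5-fold cross-validation
param_grid = {'alpha': [1e-4, 1e-3, 1e-2, 1e-1, 1.0]}
search = GridSearchCV(MultinomialNB(), param_grid, cv=5)
search.fit(X_train, y_train)
print("best alpha:", search.best_params_['alpha'])
print("best cross-validated accuracy:", search.best_score_)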
