基于weka平台手工实现朴素贝叶斯分类

一、贝叶斯定理

B事件发生后,A事件发生的概率可以如下表示:

p ( A ∣ B ) = p ( A ∩ B ) P ( B ) (1) p(A|B)=\frac{p(A\cap B)}{P(B)}\tag{1} p(AB)=P(B)p(AB)(1)

A事件发生后,B事件发生的概率可以如下表示:

p ( B ∣ A ) = p ( A ∩ B ) P ( A ) (2) p(B|A)=\frac{p(A\cap B)}{P(A)}\tag{2} p(BA)=P(A)p(AB)(2)

二者做比:

P ( A ∣ B ) P ( B ∣ A ) = P ( A ) P ( B ) (3) \frac{P(A|B)}{P(B|A)}=\frac{P(A)}{P(B)}\tag{3} P(BA)P(AB)=P(B)P(A)(3)

P ( B ∣ A ) P(B|A) P(BA) 乘到等式右边后,我们就叨叨了如下贝叶斯定理:

P ( A ∣ B ) = P ( B ∣ A ) P ( A ) P ( B ) (4) P(A|B)=\frac{P(B|A)P(A)}{P(B)}\tag{4} P(AB)=P(B)P(BA)P(A)(4)

二、贝叶斯分类

将贝叶斯定理的变量名称稍作变换,我们就得到了贝叶斯公式:

P ( c ∣ x ) = P ( x ∣ c ) P ( c ) P ( x ) (5) P(c|\bm{x})=\frac{P(\bm{x}|c)P(c)}{P(\bm{x})}\tag{5} P(cx)=P(x)P(xc)P(c)(5)

其中, P ( c ) P(c) P(c) 表示数据集中 l a b e l label label c c c 类样本的概率, x \bm{x} x 是输入属性, P ( x ) P(\bm{x}) P(x) 表示输入 x \bm{x} x 发生的概率, P ( x ∣ c ) P(\bm{x}|c) P(xc) c c c 发生条件下 x \bm{x} x 发生的概率。

我们通过公式5,来表示我们把输入 x \bm{x} x 分为 c c c 类的概率,这就是贝叶斯分类

进一步理解, P ( c ) P(c) P(c)叫做先验概率 P ( x ∣ c ) P(x|c) P(xc)叫做似然概率,二者相乘,最终的结果就可以很好的表征样本为某个类别的可能性大小。

三、朴素贝叶斯

从上面的式5可以看出,我们如果想要预测一个输入 x \bm{x} x的类别,我们只需要得到训练数据集中的 P ( c ) P(c) P(c) P ( x ∣ c ) P(\bm{x}|c) P(xc) 就可以了, P ( x ) P(x) P(x)不需要。

因为我们的预测过程是,给出一个样本输入 x x x,假设这个样本可能的类别为 C = { c ∣ c 0 , c 1 , . . . , c n } C=\{c|c_0,c_1,...,c_n\} C={ cc0,c1,...,cn}

我们根据贝叶斯公式,计算 P ( c 0 ) P(c_0) P(c0) P ( c 1 ) P(c_1) P(c1),……, P ( c n ) P(c_n) P(cn)

最后,我们选择一个最大的 P ( c ) P(c) P(c)作为我们最终的预测类别 c ′ c' c,模型预测结束。

而在这个过程中,我们的输入 x x x 是相同的,因此 P ( x ) P(\bm{x}) P(x) 也是相同的,所以我们不需要管它,就选择分子最大的就可以了。

3.1 计算 P ( c ) P(c) P(c)

这个很好计算,我们直接统计一下,数据集中不同类别的数据的频率就可以了,大数定律告诉我们,当数据集足够大的时候,我们可以使用频率来逼近概率。

3.2 计算 P ( x ∣ c ) P(\bm{x}|c) P(xc)

这个思路也很简单,同样是统计,统计数据集中,label属于 c c c 的,同时输入属性为 x \bm{x} x 的数据的频率,在数据量比较大的情况下,同样使用频率逼近概率。

但是,这里的 x \bm{x} x 是由不同的输入属性组合到一起最终合成的,它具有非常多种的情况,是组合问题,这种问题统计起来是会出现复杂度爆炸的情况的,无法在多项式的时间内完成程序的运算,换言之为一种 NP 难问题。

3.3 属性独立假设

为了解决这种 NP 难问题,我们采用了属性条件独立假设,进而就诞生了朴素贝叶斯

朴素贝叶斯分类通过属性独立假设,近似求解了 P ( x ∣ c ) P(\bm{x}|c) P(xc)这个NP难问题,而且近似的效果非常好,可以取得很棒的分类效果。

假设所有属性为独立分布的,我们可以得到如下式子:

P ( x ∣ c ) = ∏ i = 0 d P ( x i ∣ c ) (6) P(\bm{x}|c)=\prod_{i=0}^dP(x_i|c)\tag{6} P(xc)=i=0dP(xic)(6)

其中d为属性数目, x i x_i xi x \bm{x} x 在第 i i i 个属性上的取值。

将式(6)带入式(5)可以得到如下结果:

P ( c ∣ x ) = P ( x ∣ c ) P ( c ) P ( x ) = P ( c ) P ( x ) ∏ i = 0 d P ( x i ∣ c ) (7) P(c|\bm{x})=\frac{P(\bm{x}|c)P(c)}{P(\bm{x})}=\frac{P(c)}{P(\bm{x})}\prod_{i=0}^dP(x_i|c)\tag{7} P(cx)=P(x)P(xc)P(c)=P(x)P(c)i=0dP(xic)(7)

这里面的 P ( x i ∣ c ) P(x_i|c) P(xic) 是很有限的,它的数量可以表示为 ∑ i = 1 n n u m A t t r i b u t e s V a l u e s ( i ) × n u m C l a s s V a l u e s \sum_{i=1}^n numAttributesValues(i)\times numClassValues i=1nnumAttributesValues(i)×numClassValues。在程序设计中我们可以采用一个二维的List数组、三维double数组或者二维数组表示三维数组等多种方法来存储这个变量。(具体可以看下面代码,为了代码的可读性和简洁性,我是采用二维List来进行表示的)

其中 n n n表示输入 x \bm{x} x的属性数量, n u m A t t r i b u t e s V a l u e s ( i ) numAttributesValues(i) numAttributesValues(i) 表示第i个属性的可能取值的数量, n u m C l a s s V a l u e s numClassValues numClassValues 表示 label 有多少个类别。

从上面也可以看出,贝叶斯分类器天然是用来处理名词性属性的,如果我们遇到了数值型属性,就需要进行一下数据离散化处理,才能采用贝叶斯进行分类。数据离散化处理方法有很多,比如:等高分箱、等宽分享,以及基于概率分布的划分等。

四、基于weka的代码实现

package weka.classifiers.myf;

import weka.classifiers.Classifier;
import weka.core.*;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.List;

/**
 * @author YFMan
 * @Description 朴素贝叶斯 分类器
 * @Date 2023/5/14 18:48
 */
public class myNaiveBayes extends Classifier {
    
    

    // 用于存储 朴素贝叶斯 属性参数
    protected List<Integer>[][] m_Distributions;

    // 用于存储 朴素贝叶斯 类别参数
    protected List<Integer> m_ClassDistribution;

    // 类别参数 的 种类数量
    protected int m_NumClasses;

    // 存储训练数据
    protected Instances m_Instances;

    /*
     * @Author YFMan
     * @Description 训练分类器,初始化 属性参数 和 类别参数
     * @Date 2023/5/14 21:42
     * @Param [instances 训练数据]
     * @return void
     **/
    public void buildClassifier(Instances instances) throws Exception {
    
    
        // 初始化训练数据
        m_Instances = instances;
        // 初始化类别参数 的 种类数量
        m_NumClasses = instances.numClasses();

        // 初始化 属性参数
        m_Distributions = new List[instances.numAttributes() - 1][m_NumClasses];
        for(int i=0;i<instances.numAttributes() - 1;i++){
    
    
            for(int j=0;j<m_NumClasses;j++){
    
    
                m_Distributions[i][j] = new ArrayList<>();
            }
        }
        // 初始化 类别参数
        m_ClassDistribution = new ArrayList<>();
        for(int i=0;i<m_NumClasses;i++){
    
    
            m_ClassDistribution.add(0);
        }

        // 获取属性参数的枚举类型
        Enumeration attributeEnumeration = instances.enumerateAttributes();
        // 遍历属性参数
        while (attributeEnumeration.hasMoreElements()) {
    
    
            // 获取属性参数
            Attribute attribute = (Attribute) attributeEnumeration.nextElement();
            // 获取属性参数的索引
            int attributeIndex = attribute.index();
            // 获取属性参数的值的枚举类型
            Enumeration attributeValueEnumeration = attribute.enumerateValues();
            // 遍历属性参数的值
            while (attributeValueEnumeration.hasMoreElements()) {
    
    
                // 获取属性参数的值
                String attributeValue = (String) attributeValueEnumeration.nextElement();
                // 遍历类别参数
                for (int classIndex = 0; classIndex < m_NumClasses; classIndex++) {
    
    
                    // 初始化 属性参数 的 某个值 的 某个类别参数 的 计数
                    m_Distributions[attributeIndex][classIndex].add(0);
                }
            }
        }

        // 遍历训练数据
        for (int instanceIndex = 0; instanceIndex < instances.numInstances(); instanceIndex++) {
    
    
            // 获取训练数据的实例
            Instance instance = instances.instance(instanceIndex);
            // 获取训练数据的类别参数的值
            int classValue = (int) instance.classValue();
            // 遍历属性参数
            for (int attributeIndex = 0; attributeIndex < instances.numAttributes() - 1; attributeIndex++) {
    
    
                // 获取训练数据的属性参数的值
                int attributeValue = (int) instance.value(attributeIndex);
                // 计数
                m_Distributions[attributeIndex][classValue].set(attributeValue,
                        m_Distributions[attributeIndex][classValue].get(attributeValue) + 1);
            }
            // 计数
            m_ClassDistribution.set(classValue, m_ClassDistribution.get(classValue) + 1);
        }
    }

    /*
     * @Author YFMan
     * @Description 根据给定的实例,预测其类别
     * @Date 2023/5/14 21:43
     * @Param [instance 给定的实例]
     * @return double[]
     **/
    public double[] distributionForInstance(Instance instance)
            throws Exception {
    
    
        // 初始化预测概率数组
        double[] predictionProbability = new double[m_NumClasses];
        // 遍历类别参数
        for (int classIndex = 0; classIndex < m_NumClasses; classIndex++) {
    
    
            // 初始化预测概率
            double prediction = 1;
            // 遍历属性参数
            for (int attributeIndex = 0; attributeIndex < m_Instances.numAttributes() - 1; attributeIndex++) {
    
    
                // 获取属性参数的值
                int attributeValue = (int) instance.value(attributeIndex);
                // 获取 当前属性 可能的取值数
                int attributeValueCount = m_Distributions[attributeIndex][classIndex].size();
                // 计算条件概率P(x|c) (当前属性值在当前类别下占的比例) (拉普拉斯平滑)
                double p_x_c =  (double) (m_Distributions[attributeIndex][classIndex].get(attributeValue) + 1) /
                        (m_ClassDistribution.get(classIndex) + attributeValueCount);
                // 计算预测概率
                prediction *= p_x_c;
            }
            // 计算先验概率P(c) (当前类别占总类别的比例) (拉普拉斯平滑)
            double p_c = (double) (m_ClassDistribution.get(classIndex) + 1) /
                    (m_Instances.numInstances() + m_NumClasses);
            // 计算预测概率
            predictionProbability[classIndex] = prediction * p_c;
        }
        // 归一化
        Utils.normalize(predictionProbability);
        // 返回预测概率数组
        return predictionProbability;
    }

    /*
     * @Author YFMan
     * @Description 主函数
     * @Date 2023/5/14 21:54
     * @Param [argv 命令行参数]
     * @return void
     **/
    public static void main(String[] argv) {
    
    
        runClassifier(new myNaiveBayes(), argv);
    }
}

猜你喜欢

转载自blog.csdn.net/myf_666/article/details/130674476