hadoop-朴素贝叶斯算法的简单实现

文章转自:https://blog.csdn.net/Angelababy_huan/article/details/53046151

  贝叶斯分类器的分类原理是通过某对象的先验概率,利用贝叶斯公式计算出其后验概率,即该对象属于某一类的概率,选择具有最大后验概率的类作为该对象所属的类。    

    以下为一个简单的例子:

    数据:天气情况和每天是否踢足球的记录表

日期 踢足球 天气 温度 湿度 风速
1号 否(0) 晴天(0) 热(0) 高(0) 低(0)
2号 否(0) 晴天(0) 热(0) 高(0) 高(1)
3号 是(1) 多云(1) 热(0) 高(0) 低(0)
4号 是(1) 下雨(2) 舒适(1) 高(0) 低(0)
5号 是(1) 下雨(2) 凉爽(2) 正常(1) 低(0)
6号 否(0) 下雨(2) 凉爽(2) 正常(1) 高(1)
7号 是(1) 多云(1) 凉爽(2) 正常(1) 高(1)
8号 否(0) 晴天(0) 舒适(1) 高(0) 低(0)
9号 是(1) 晴天(0) 凉爽(2) 正常(1) 低(0)
10号 是(1) 下雨(2) 舒适(1) 正常(1) 低(0)
11号 是(1) 晴天(0) 舒适(1) 正常(1) 高(1)
12号 是(1) 多云(1) 舒适(1) 高(0) 高(1)
13号 是(1) 多云(1) 热(0) 正常(1) 低(0)
14号 否(0) 下雨(2) 舒适(1) 高(0) 高(1)
15号 晴天(0) 凉爽(2) 高(0) 高(1)
    需要预测15号,在这种天气情况下是否踢球。

    假设15号去踢球,踢球的概率计算过程如下:

    P(踢球的概率) = 9/14

    P(晴天|踢) = 踢球天数中晴天踢球的次数/踢球次数 = 2/9

    P(凉爽|踢) = 踢球天数中凉爽踢球的次数/踢球次数 = 3/9

    P(湿度高|踢) = 踢球天数中湿度高踢球的次数/踢球次数 = 3/9

    P(风速高|踢) = 踢球天数中风速高踢球的次数/踢球次数 = 3/9

    则15号踢球的概率P = 9/14 * 2/9 * 3/9 * 3/9 * 3/9 = 0.00529

    按照上述步骤还可计算出15号不去踢球的概率P = 5/14 * 3/5 * 1/5 * 4/5 * 3/5 = 0.02057

    可以看出,15号不去踢球的概率大于去踢球的概率,则可预测说,15号不去踢球。

    理解朴素贝叶斯的流程之后,开始设计MR程序。在Mapper中,对训练数据进行拆分,也就是将这条训练数据拆分为类别和训练数据,将训练数据以自定义值类型来保存,然后传递给Reducer。

                

Mapper:

import java.io.IOException;   
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;  
import org.apache.hadoop.mapreduce.Mapper;  
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class BayesMapper extends Mapper<Object, Text, IntWritable, MyWritable> {  
    Logger log = LoggerFactory.getLogger(BayesMapper.class);  
    private IntWritable myKey = new IntWritable();  
    private MyWritable myValue = new MyWritable();
    @Override  
    protected void map(Object key, Text value, Context context)  
            throws IOException, InterruptedException {  
        log.info("***"+value.toString());  
        int[] values = getIntData(value);  
        int label = values[0];  //存放类别  
        int[] result = new int[values.length-1]; //存放数据  
        for(int i =1;i<values.length;i++){  
            result[i-1] = values[i];
        }  
        myKey.set(label);  
        myValue.setValue(result);  
        context.write(myKey, myValue);  
    }  
    private int[] getIntData(Text value) {  
        String[] values = value.toString().split(",");  
        int[] data = new int[values.length];
        for(int i=0; i < values.length;i++){
        	if(!values[i].equals(""))
        		if(values[i].matches("^[0-9]+$"))
        			data[i] = Integer.parseInt(values[i]);  
        }  
        return data;  
    }  
}  

MyWritable:

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import org.apache.hadoop.io.Writable;

public class MyWritable implements Writable{  
    private int[] value;  
    public MyWritable() {  
        
    }  
    public MyWritable(int[] value){  
        this.setValue(value);  
    } 
    public void write(DataOutput out) throws IOException {  
        out.writeInt(value.length);  
        for(int i=0; i<value.length;i++){  
            out.writeInt(value[i]);  
        }  
    }   
    public void readFields(DataInput in) throws IOException {  
        int vLength = in.readInt();  
        value = new int[vLength];  
        for(int i=0; i<vLength;i++){  
            value[i] = in.readInt();  
        }  
    }  
    public int[] getValue() {  
        return value;  
    }  
    public void setValue(int[] value) {  
        this.value = value;  
    }  
}  

Reducer:

import java.io.BufferedReader;
import java.io.IOException;  
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.HashMap;  
import org.apache.hadoop.conf.Configuration;  
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;  
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapreduce.Reducer;  
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class BayesReducer extends Reducer<IntWritable, MyWritable, IntWritable, IntWritable>{  
    Logger log = LoggerFactory.getLogger(BayesReducer.class);  
    private String testFilePath;  
    // 测试数据  
    private ArrayList<int[]> testData = new ArrayList<>();  
    // 保存相同k的所有数据  
    private ArrayList<CountAll> allData = new ArrayList<>();  
    @Override  
    protected void setup(Context context)  
            throws IOException, InterruptedException {  
        Configuration conf = context.getConfiguration();  
        testFilePath = conf.get("home/5.txt");  
        Path path = new Path("home/5.txt");  
        FileSystem fs = path.getFileSystem(conf);  
        readTestData(fs,path);  
    }  
    @Override  
    protected void reduce(IntWritable key, Iterable<MyWritable> values,  
            Context context)  
            throws IOException, InterruptedException {  
        Double[] myTest = new Double[testData.get(0).length-1];  
        for(int i=0;i<myTest.length;i++){  
            myTest[i] = 1.0;  
        }  
        Long sum = 2L;  
        // 计算每个类别中,每个属性值为1的个数  
        for (MyWritable myWritable : values) {  
            int[] myvalue = myWritable.getValue();  
            for(int i=0; i < myvalue.length;i++){  
                myTest[i] += myvalue[i];  
            }  
            sum += 1;  
        }  
        for(int i=0;i<myTest.length;i++){  
            myTest[i] = myTest[i]/sum;  
        }  
        allData.add(new CountAll(sum,myTest,key.get()));  
    }  
    private IntWritable myKey = new IntWritable();  
    private IntWritable myValue = new IntWritable();  
      
    protected void cleanup(Context context)  
            throws IOException, InterruptedException {  
        // 保存每个类别的在训练数据中出现的概率  
        // k,v  0,0.4  
        // k,v  1,0.6  
        HashMap<Integer, Double> labelG = new HashMap<>();  
        Long allSum = getSum(allData); //计算训练数据的长度  
        for(int i=0; i<allData.size();i++){  
            labelG.put(allData.get(i).getK(),   
                    Double.parseDouble(allData.get(i).getSum().toString())/allSum);  
        }  
        //test的长度 要比训练数据中的长度大1  
        int sum = 0;  
        int yes = 0;  
        for(int[] test: testData){  
            int value = getClasify(test, labelG);  
            if(test[0] == value){  
                yes += 1;  
            }  
            sum +=1;  
            myKey.set(test[0]);  
            myValue.set(value);  
            context.write(myKey, myValue);  
        }  
        System.out.println("正确率为:"+(double)yes/sum);  
    }  
    /*** 
     * 求得所有训练数据的条数 
     * @param allData2 
     * @return 
     */  
    private Long getSum(ArrayList<CountAll> allData2) {  
        Long allSum = 0L;  
        for (CountAll countAll : allData2) {  
            log.info("类别:"+countAll.getK()+"数据:"+myString(countAll.getValue())+"总数:"+countAll.getSum());  
            allSum += countAll.getSum();  
        }  
        return allSum;  
    }  
    /*** 
     * 得到分类的结果 
     * @param test 
     * @param labelG 
     * @return 
     */  
    private int getClasify(int[] test,HashMap<Integer, Double> labelG ) {  
        double[] result = new double[allData.size()]; //以类别的长度作为数组的长度  
        for(int i = 0; i<allData.size();i++){  
            double count = 0.0;  
            CountAll ca = allData.get(i);  
            Double[] pdata = ca.getValue();  
            for(int j=1;j<test.length;j++){  
                if(test[j] == 1){  
                    // 在该类别中,相同位置上的元素的值出现1的概率  
                    count += Math.log(pdata[j-1]);   
                }else{  
                    count += Math.log(1- pdata[j-1]);   
                }  
                log.info("count: "+count);  
            }  
            count += Math.log(labelG.get(ca.getK()));  
            result[i] = count;  
        }   
        if(result[0] > result[1]){  
            return 0;  
        }else{  
            return 1;  
        }  
    }  
    /*** 
     * 读取测试数据 
     * @param fs 
     * @param path 
     * @throws NumberFormatException 
     * @throws IOException 
     */  
    private void readTestData(FileSystem fs, Path path) throws NumberFormatException, IOException {  
        FSDataInputStream data = fs.open(path);  
        BufferedReader bf = new BufferedReader(new InputStreamReader(data));  
        String line = "";  
        while ((line = bf.readLine()) != null) {  
            String[] str = line.split(",");  
            int[] myData = new int[str.length];  
            for(int i=0;i<str.length;i++){
            	if(str[i]!=""||!str[i].equals(""))
            		if(str[i].matches("^[0-9]+$"))
                myData[i] = Integer.parseInt(str[i]);  
            }  
            testData.add(myData);  
        }  
        bf.close();  
        data.close();  
          
    }  
    public static String myString(Double[] arr){  
        String num = "";  
        for(int i=0;i<arr.length;i++){  
            if(i==arr.length-1){  
                num += String.valueOf(arr[i]);  
            }else{  
                num += String.valueOf(arr[i])+',';  
            }  
        }  
        return num;  
    }  
}  

CountAll:

public class CountAll {  
    private Long sum;  
    private Double[] value;  
    private int k;  
    public CountAll(){}  
    public CountAll(Long sum, Double[] value,int k){  
        this.sum = sum;  
        this.value = value;  
        this.k = k;  
    }  
    public Double[] getValue() {  
        return value;  
    }  
    public void setValue(Double[] value) {  
        this.value = value;  
    }  
    public Long getSum() {  
        return sum;  
    }  
    public void setSum(Long sum) {  
        this.sum = sum;  
    }  
    public int getK() {  
        return k;  
    }  
    public void setK(int k) {  
        this.k = k;  
    }  
}  

MainJob:

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.util.GenericOptionsParser;
public class MainJob {
	public static void main(String[] args) throws Exception {  
        Configuration conf = new Configuration();  
        String[] otherArgs = new GenericOptionsParser(conf, args)  
                .getRemainingArgs();  
        if (otherArgs.length != 2) {  
            System.err.println("Usage: numbersum <in> <out>");  
            System.exit(2);  
        }  
        long startTime = System.currentTimeMillis();// 计算时间  
        Job job = new Job(conf);  
        job.setJarByClass(MainJob.class);  
        job.setMapperClass(BayesMapper.class);  
        job.setReducerClass(BayesReducer.class);
        job.setMapOutputKeyClass(IntWritable.class);
        job.setMapOutputValueClass(MyWritable.class);
        job.setOutputKeyClass(IntWritable.class);  
        job.setOutputValueClass(MyWritable.class);
        FileInputFormat.addInputPath(job, new Path(otherArgs[0]));  
        FileOutputFormat.setOutputPath(job, new Path(otherArgs[1]));  
        job.waitForCompletion(true);  
        long endTime = System.currentTimeMillis();  
        System.out.println("time=" + (endTime - startTime));  
        System.exit(0);  
    }  

}

测试数据:

1,0,0,0,1,0,0,0,1,1,0,0,0,1,1,0,0,0,0,0,0,0,0  
1,0,0,1,1,0,0,0,1,1,0,0,0,1,1,0,0,0,0,0,0,0,1  
1,1,0,1,0,1,0,0,1,0,1,0,0,1,0,0,0,0,0,0,0,0,0  
1,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,1,1,1  
1,0,0,0,0,0,0,0,1,0,0,0,0,1,0,1,1,0,0,0,0,0,0  
1,0,0,0,1,0,0,0,0,1,0,0,0,1,1,0,1,0,0,0,1,0,1  
1,1,0,1,1,0,0,0,1,0,1,0,1,1,0,0,0,0,0,0,0,1,1  
1,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,0,1  
1,0,0,1,0,0,0,1,1,0,0,0,0,1,0,1,0,0,0,0,0,1,1  
1,0,1,0,0,0,0,1,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0  
1,1,1,0,0,1,0,1,0,0,1,1,1,1,0,0,1,1,1,1,1,0,1  
1,1,1,0,0,1,1,1,0,1,1,1,1,0,1,0,0,1,0,1,1,0,0  
1,1,0,0,0,1,0,0,0,0,1,0,0,1,0,0,0,0,0,0,0,1,1  
1,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,1,0,1,0,0,1,1  
1,1,0,1,1,0,0,1,1,1,0,1,1,1,1,1,1,0,1,1,0,1,1  
1,0,1,1,0,0,1,1,1,0,0,0,1,1,0,0,1,1,1,0,1,1,1  
1,0,0,1,1,0,0,0,1,1,0,0,0,1,1,0,1,0,0,0,0,1,0  
1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1  
1,1,0,1,0,1,0,1,1,0,1,0,1,1,0,0,0,1,0,0,1,1,0  
1,1,0,0,0,1,0,0,0,0,1,0,0,1,0,0,0,0,0,0,1,0,0  
1,0,0,0,0,0,0,1,0,0,0,1,1,1,0,0,0,0,0,0,0,1,1  
1,1,0,0,0,1,1,0,1,0,0,1,0,0,0,0,0,0,0,1,1,0,0  
1,1,1,0,0,1,1,1,0,0,1,1,1,0,0,0,0,0,0,1,0,0,0  
1,0,0,0,0,0,0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,0,0  
1,0,0,0,0,0,1,0,0,0,0,1,0,0,0,0,0,1,0,1,0,0,0  
1,0,0,0,0,0,0,1,0,0,0,0,1,0,0,0,1,0,0,0,0,0,0  

验证数据:

1,1,0,0,1,1,0,0,0,1,1,0,0,0,1,1,1,0,0,1,1,0,0  
1,1,0,0,1,1,0,0,0,0,1,0,0,1,0,0,0,0,0,0,0,0,0  
1,0,0,0,1,0,1,0,0,1,0,1,0,0,1,1,0,0,0,0,0,0,1  
1,0,1,1,1,0,0,1,0,1,0,0,1,1,1,0,1,0,0,0,0,1,0  
1,0,0,1,0,0,0,0,1,0,0,1,0,1,1,0,1,0,0,0,0,0,1  
1,0,0,1,1,0,1,0,0,1,0,1,0,1,0,0,1,0,0,0,0,1,1  
1,1,0,0,1,0,0,1,1,1,1,0,1,1,1,0,1,0,0,0,1,0,1  
1,1,0,0,1,0,0,0,0,1,1,0,0,0,1,1,0,0,0,0,0,0,0  
1,0,0,0,0,0,0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,0,0  
1,1,0,0,1,1,1,0,0,1,1,1,0,0,1,0,1,1,0,1,0,0,0  
1,1,0,0,0,1,0,0,0,1,1,0,0,1,1,1,0,0,0,1,0,0,0  
1,1,0,0,0,1,1,0,0,0,1,1,0,0,0,0,0,0,0,1,0,0,0  
1,0,0,0,0,0,1,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,1  
1,1,0,0,0,1,0,0,0,1,1,0,0,0,1,0,0,0,1,1,0,0,0  
1,1,0,0,1,1,0,0,0,1,1,0,0,0,0,0,1,0,0,1,1,0,0  
1,1,0,1,0,1,0,0,1,0,1,0,0,1,0,0,0,0,1,0,0,1,0  
1,1,1,0,0,1,1,1,1,0,1,1,1,1,0,0,0,1,0,0,0,1,1  
1,1,0,0,0,0,1,1,0,0,1,1,1,0,0,0,0,1,0,0,0,0,1  
1,1,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0  
1,1,1,1,0,1,0,1,1,0,1,0,1,1,0,0,1,0,0,0,1,1,0  
1,1,0,0,0,1,0,0,0,0,1,0,0,0,0,0,1,0,1,0,1,0,0  
1,1,0,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,1,0,1,1,1  
1,0,0,1,1,1,0,0,1,1,1,0,0,1,1,1,1,0,1,0,1,1,0  
1,1,1,0,1,1,1,1,0,0,0,1,1,0,0,0,1,1,0,0,1,0,0  
1,1,1,0,0,1,0,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0  
1,1,1,0,0,1,1,1,0,0,1,0,0,0,0,0,1,0,0,0,0,0,0  
1,1,0,1,0,1,0,0,1,0,0,0,0,1,0,0,0,1,0,0,0,1,0  
1,1,1,1,1,0,1,1,1,0,1,0,0,1,1,1,1,0,0,1,1,0,0 

运行结果:



猜你喜欢

转载自blog.csdn.net/zw159357/article/details/80369406
今日推荐