朴素贝叶斯分类器 拉普拉斯变换 java

/**
 * 朴素贝叶斯分类器   拉普拉变化的重要性(暂未实现)  小样本数据有坑   特征为离散型数值化
 * @author ysh 1208706282
 *
 */
public class NavieBayes {
    Map<Integer,Integer> labelInfo;
    Map<String,FeatureInfo> featureInfo;
    List<Sample> samples;
    static class Sample{
        int label;
        List<Integer> feature;
    }
    static class FeatureInfo{
        int label;
        int featureId;
        int featureValue;
        int count;
        double rate;
    }
    /**
     * 加载数据
     * @param path
     * @param regex
     * @throws Exception
     */
    public  void loadData(String path,String regex) throws Exception{
        samples = new ArrayList<Sample>();
        BufferedReader reader = new BufferedReader(new FileReader(path));
        String line = null;
        String splits[] = null;
        Sample sample = null;
        while(null != (line=reader.readLine())){
            splits = line.split(regex);
            sample = new Sample();
            sample.feature = new ArrayList<Integer>(splits.length-1);
            for(int i=0;i<splits.length-1;i++){
                sample.feature.add(new Integer(splits[i]));
            }
            sample.label = Integer.valueOf(splits[splits.length-1]);
            samples.add(sample);
        }
        reader.close();
    }
    /**
     * 加载验证测试集
     * @param path
     * @param regex
     * @throws Exception
     */
    public  List<Sample> loadTestData(String path,boolean hasLabel,String regex) throws Exception{
        List<Sample> samples = new ArrayList<Sample>();
        BufferedReader reader = new BufferedReader(new FileReader(path));
        String line = null;
        String splits[] = null;
        Sample sample = null;
        while(null != (line=reader.readLine())){
            splits = line.split(regex);
            sample = new Sample();
            sample.feature = new ArrayList<Integer>(splits.length-1);
            for(int i=0;i<splits.length-1;i++){
                sample.feature.add(new Integer(splits[i]));
            }
            if(hasLabel){
                sample.label = Integer.valueOf(splits[splits.length-1]);
            }
            samples.add(sample);
        }
        reader.close();
        return samples;
    }
    public void laplaceSmooth(){
        
    }
    public void train(){
        featureInfo = new HashMap<String,FeatureInfo>();
        labelInfo = new HashMap<Integer,Integer>();
        
        String key = null;
        FeatureInfo info = null;
        
        for(Sample sample:samples) {
            if(null == labelInfo.get(sample.label)){
                labelInfo.put(sample.label, 1);
            }else{
                labelInfo.put(sample.label, labelInfo.get(sample.label)+1);
            }
            for(int i=0;i<sample.feature.size();i++){
                key = sample.label+";"+i+";"+sample.feature.get(i);
                info = featureInfo.get(key);
                if(null == info){
                    info = new FeatureInfo();
                    info.count = 1;
                    info.featureId = i;
                    info.featureValue = sample.feature.get(i);
                    info.label = sample.label;
                    featureInfo.put(key, info);
                }else{
                    info.count += 1;
                }
            }
        }
        Iterator<Entry<Integer,Integer>> iter = labelInfo.entrySet().iterator();
        Entry<Integer,Integer> entry = null;
        while(iter.hasNext()){
            entry = iter.next();
            System.out.println("label: "+entry.getKey()+" count: "+entry.getValue());
        }
        
        Set<String> set = featureInfo.keySet();
        for(String str:set){
            System.out.println(str+" count:"+featureInfo.get(str).count);
        }
    }
    public int classify(Sample sample){
        int label = 0;
        double max = -1;
        String key = null;
        FeatureInfo info = null;
        Set<Integer> set = labelInfo.keySet();
        for(Integer la:set){
            double rate = 1;
            for(int i=0;i<sample.feature.size();i++){
                key = la.intValue()+";"+i+";"+sample.feature.get(i);
                info = featureInfo.get(key);
                if(info != null){
                    rate *= (1.0*info.count/labelInfo.get(la));
                }else{
                    //System.out.println("error");
                    rate *= (1.0/labelInfo.get(la));
                }
            }
            rate *= (1.0*labelInfo.get(la)/samples.size());
            if(rate > max){
                max = rate;
                label = la.intValue();
            }
            //System.out.println("label: "+la+" rate:"+rate);
        }
        
        return label;
    }
    /**
     * @param args
     * @throws Exception
     */
    public static void main(String[] args) throws Exception {
        // TODO Auto-generated method stub
        String pathTrain = "F:/uci/data/car/car.train_train";
        String pathTest = "F:/uci/data/car/car.train_test";
        
        NavieBayes nb = new NavieBayes();
        nb.loadData(pathTrain, ",");
        nb.train();
        List<Sample> test = nb.loadTestData(pathTest,true,",");
        int count = 0;
        for(Sample sample:test){
            int predict = nb.classify(sample);
            System.out.println("label: "+sample.label+" predict: "+predict);
            if(predict == sample.label){
                count++;
            }
        }
        System.out.println("right rate: "+(count*1.0/test.size()));
    }
}


猜你喜欢

转载自blog.csdn.net/ysh126/article/details/53100543