贝叶斯网络及朴素贝叶斯网络的实现

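	核心决策理论：选择后验概率最高的类别作为决策。即：出现一个需要分类的新点 $(x, y)$ 时，只需计算 $\max\big(p(c_1 \mid x, y),\ p(c_2 \mid x, y),\ \dots,\ p(c_n \mid x, y)\big)$，取得最大概率的类别标签就是这个新点的分类。朴素贝叶斯假设各属性在给定类别下条件独立，因此 $p(c_i \mid x, y) \propto p(c_i)\, p(x \mid c_i)\, p(y \mid c_i)$，比较右侧的乘积即可。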

贝叶斯公式：

（原文此处为一张插图，图片已缺失）

package baseNaiveBayesian;

import java.io.IOException;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

/**
 * Base class for a categorical naive Bayes classifier.
 *
 * <p>Every row of {@link #trainingSet} is a list of attribute values whose last element is the
 * class label. Subclasses are responsible for filling the training set.
 */
public abstract class NaiveBayesianBase {
	/** Training rows; the last element of each row is its class label. */
	protected ArrayList<ArrayList<String>> trainingSet;

	/** Creates a classifier with an empty training set. */
	public NaiveBayesianBase() {
		trainingSet = new ArrayList<ArrayList<String>>();
	}

	/**
	 * Loads training rows from an interactive source into {@link #trainingSet}.
	 * The meaning of the return code is defined by the subclass.
	 */
	public abstract int inputTrainingSet() throws IOException;

	/**
	 * Loads training rows from the file at {@code path} into {@link #trainingSet}.
	 * The meaning of the return code is defined by the subclass.
	 */
	public abstract int readTrainingSet(String path) throws IOException;

	/** Returns the class label of {@code row}, i.e. its last element. */
	private static String labelOf(ArrayList<String> row) {
		return row.get(row.size() - 1);
	}

	/**
	 * Groups rows by their class label (last element).
	 *
	 * @param data rows to group; the rows themselves are shared, not copied
	 * @return map from label to the rows carrying that label
	 */
	public static Map<String, ArrayList<ArrayList<String>>> dataClassification(ArrayList<ArrayList<String>> data) {
		Map<String, ArrayList<ArrayList<String>>> groups = new HashMap<String, ArrayList<ArrayList<String>>>();
		for (ArrayList<String> row : data) {
			groups.computeIfAbsent(labelOf(row), k -> new ArrayList<ArrayList<String>>()).add(row);
		}
		return groups;
	}

	/**
	 * Predicts the label of {@code testSet} as the class maximizing
	 * {@code p(class) * prod_k p(attribute_k | class)}.
	 *
	 * <p>Only the first {@code min(columns - 1, testSet.size())} values are compared, so a row
	 * that still carries its label in the last position can be passed directly.
	 *
	 * @param testSet attribute values of the row to classify (a single row)
	 * @return the predicted class label
	 * @throws IllegalStateException if the training set is empty
	 */
	public String predictClassification(ArrayList<String> testSet) {
		if (trainingSet.isEmpty()) {
			throw new IllegalStateException("training set is empty");
		}
		Map<String, ArrayList<ArrayList<String>>> doc = dataClassification(trainingSet);
		int attributeCount = Math.min(trainingSet.get(0).size() - 1, testSet.size());
		String bestLabel = null;
		double bestScore = Double.NEGATIVE_INFINITY;
		for (Map.Entry<String, ArrayList<ArrayList<String>>> entry : doc.entrySet()) {
			String label = entry.getKey();
			int classSize = entry.getValue().size();
			// Sum logs instead of multiplying so many small factors cannot underflow to 0;
			// the argmax is the same because log is monotonic.
			double score = Math.log((double) classSize / trainingSet.size());
			for (int k = 0; k < attributeCount; ++k) {
				double pCA = pOfClassificationAttributes(testSet.get(k), k, label);
				// A value never seen with this class would zero the whole product;
				// fall back to 1/classSize as a crude smoothing estimate.
				if (pCA <= 0.0) pCA = 1.0 / classSize;
				score += Math.log(pCA);
			}
			if (bestLabel == null || score > bestScore) {
				bestScore = score;
				bestLabel = label;
			}
		}
		return bestLabel;
	}

	/**
	 * Returns the relative frequency of {@code attribute} in column {@code index} among the
	 * training rows labelled {@code classificationclass}.
	 *
	 * @return {@code count / total}, or 0.0 when no training row has that label
	 */
	public double pOfClassificationAttributes(String attribute, int index, String classificationclass) {
		int count = 0;
		int total = 0;
		for (ArrayList<String> row : trainingSet) {
			if (labelOf(row).equals(classificationclass)) {
				++total;
				if (row.get(index).equals(attribute)) ++count;
			}
		}
		return total == 0 ? 0.0 : (double) count / total;
	}

	/** Returns the fraction of {@code rows} whose predicted label equals their actual label. */
	private double accuracy(ArrayList<ArrayList<String>> rows) {
		int correct = 0;
		for (ArrayList<String> row : rows) {
			if (predictClassification(row).equals(labelOf(row))) ++correct;
		}
		return (double) correct / rows.size();
	}

	/**
	 * Evaluates the model on a random hold-out split and prints its accuracy.
	 *
	 * <p>A fraction {@code d} of the training rows is removed at random and used as the test set.
	 * The remaining rows serve as training data while predicting. The held-out rows are appended
	 * back afterwards, even if prediction throws. The set's contents are restored, but its order
	 * may change.
	 *
	 * @param d fraction of rows to hold out, in {@code [0, 1)}. Other values are ignored, as is a
	 *     split that would leave no test row or no training row.
	 */
	public void reportModel(double d) {
		// The negated range test also rejects NaN; d == 1 would leave nothing to train on.
		if (!(d >= 0.0 && d < 1.0)) return;
		int testSetCount = (int) (trainingSet.size() * d);
		if (testSetCount == 0 || testSetCount >= trainingSet.size()) return;
		ArrayList<ArrayList<String>> testSet = new ArrayList<ArrayList<String>>(testSetCount);
		try {
			for (int i = 0; i < testSetCount; ++i) {
				// Math.random() is in [0, 1), so every remaining row, including the last, can be drawn.
				int pick = (int) (Math.random() * trainingSet.size());
				testSet.add(trainingSet.remove(pick));
			}
			System.out.println("模型准确率为:" + accuracy(testSet));
		} finally {
			trainingSet.addAll(testSet);
		}
	}

	/**
	 * Prints the accuracy of the model when classifying its own training rows.
	 * Does nothing when the training set is empty.
	 */
	public void reportModelSelf() {
		if (trainingSet.isEmpty()) return;
		System.out.println("模型准确率为:" + accuracy(trainingSet));
	}
}

//病人数据资料分类
package classification;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;

import baseNaiveBayesian.NaiveBayesianBase;

/**
 * Naive Bayes classifier for comma-separated patient records. The column layout matches
 * {@link #dataDeal}; presumably this is the UCI "Hepatitis C Virus (HCV) for Egyptian patients"
 * data, which should be confirmed against the actual input file.
 *
 * <p>Raw numeric columns are discretized by {@link #dataDeal} before being stored in the
 * training set.
 */
public class PatientClassification extends NaiveBayesianBase {

	/**
	 * Reads records from standard input, one per line, until a blank line or end of input.
	 *
	 * @return always 0
	 * @throws IOException if reading standard input fails
	 * @throws NumberFormatException if a numeric column of a record cannot be parsed
	 */
	@Override
	public int inputTrainingSet() throws IOException {
		// Not closed on purpose: closing this reader would also close System.in.
		BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));
		String str;
		// readLine() returns null at end of input, so test for it before inspecting the line.
		while ((str = reader.readLine()) != null && !str.trim().isEmpty()) {
			trainingSet.add(parseRecord(str));
		}
		return 0;
	}

	/**
	 * Reads every non-blank line of the file at {@code path} as one record.
	 *
	 * @param path path of the comma-separated data file
	 * @return 0 on success, or -1 if {@code path} is not an existing regular file. In the -1 case
	 *     the absolute path that was checked is printed.
	 * @throws IOException if reading the file fails
	 * @throws NumberFormatException if a numeric column of a record cannot be parsed
	 */
	@Override
	public int readTrainingSet(String path) throws IOException {
		File file = new File(path);
		if (!file.exists() || !file.isFile()) {
			System.out.println(file.getAbsolutePath());
			return -1;
		}
		// try-with-resources closes the file even when a record fails to parse.
		try (BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(file)))) {
			String str;
			while ((str = reader.readLine()) != null) {
				// Skip blank lines (e.g. a trailing newline) instead of failing to parse them.
				if (str.trim().isEmpty()) continue;
				trainingSet.add(parseRecord(str));
			}
		}
		return 0;
	}

	/** Splits one comma-separated record into fields and discretizes it with {@link #dataDeal}. */
	private static ArrayList<String> parseRecord(String record) {
		String[] fields = record.split(",");
		ArrayList<String> raw = new ArrayList<String>(fields.length);
		for (String field : fields) {
			raw.add(field);
		}
		return dataDeal(raw);
	}

	/**
	 * Converts one raw patient record into discrete attribute codes.
	 *
	 * <p>Column layout as read below:
	 * <ul>
	 *   <li>0 age, 1 gender, 2 BMI</li>
	 *   <li>3-9 seven symptom flags (copied verbatim)</li>
	 *   <li>10 WBC, 11 RBC, 12 HGB, 13 platelets</li>
	 *   <li>14-21 eight AST/ALT readings, 22-26 five RNA readings</li>
	 *   <li>27 baseline histological grading, optionally followed by a trailing class label</li>
	 * </ul>
	 *
	 * @param line raw fields of one record
	 * @return 28 discretized attributes, followed by the label when {@code line} has more than 28
	 *     fields
	 * @throws NumberFormatException if a numeric column cannot be parsed
	 */
	private static ArrayList<String> dataDeal(ArrayList<String> line) {
		ArrayList<String> newLine = new ArrayList<String>();
		int temp;
		double tempDouble;
		// Age, binned by (age + 2) / 5: ages up to 32 -> "1", 33-62 -> "2".."7" in five-year
		// steps, anything else -> "-1".
		switch ((Integer.parseInt(line.get(0)) + 2) / 5) {
		case 0: case 1: case 2: case 3: case 4: case 5: case 6: newLine.add("1"); break;
		case 7: newLine.add("2"); break;
		case 8: newLine.add("3"); break;
		case 9: newLine.add("4"); break;
		case 10: newLine.add("5"); break;
		case 11: newLine.add("6"); break;
		case 12: newLine.add("7"); break;
		default: newLine.add("-1"); break;
		}
		// Gender, copied verbatim; also selects the HGB thresholds below.
		newLine.add(line.get(1));
		// BMI: <18 -> "1", <25 -> "2", otherwise one code per five BMI units starting at "3".
		temp = Integer.parseInt(line.get(2));
		if (temp < 18) newLine.add("1");
		else if (temp < 25) newLine.add("2");
		else newLine.add(String.valueOf(temp / 5 - 2));
		// Fever, Nausea/Vomiting, Headache, Diarrhea, Fatigue & Bone ache, Jaundice,
		// Epigastric pain: seven flags copied verbatim.
		for (int i = 0; i < 7; ++i) newLine.add(line.get(i + 3));
		// WBC (column 10)
		temp = Integer.parseInt(line.get(10));
		if (temp < 4000) newLine.add("1");
		else if (temp < 11000) newLine.add("2");
		else newLine.add("3");
		// RBC (column 11)
		tempDouble = Double.parseDouble(line.get(11));
		if (tempDouble < 3000000.00) newLine.add("1");
		else if (tempDouble < 5000000.00) newLine.add("2");
		else newLine.add("3");
		// HGB (column 12): gender code "1" uses the 14-17 range, any other code uses 12-15.
		temp = Integer.parseInt(line.get(12));
		if (newLine.get(1).equals("1")) {
			if (temp < 14) newLine.add("1");
			else if (temp <= 17) newLine.add("2");
			else newLine.add("3");
		} else {
			if (temp < 12) newLine.add("1");
			else if (temp <= 15) newLine.add("2");
			else newLine.add("3");
		}
		// Platelets (column 13)
		tempDouble = Double.parseDouble(line.get(13));
		if (tempDouble < 100000.00) newLine.add("1");
		else if (tempDouble < 255000) newLine.add("2");
		else newLine.add("3");
		// AST1, ALT1, ALT4, ALT12, ALT24, ALT36, ALT48, ALT after 24 weeks (columns 14-21)
		for (int i = 0; i < 8; ++i) {
			tempDouble = Double.parseDouble(line.get(i + 14));
			if (tempDouble < 20.00) newLine.add("1");
			else if (tempDouble <= 40.00) newLine.add("2");
			else newLine.add("3");
		}
		// RNA Base, RNA 4, RNA 12, RNA EOT, RNA EF (columns 22-26).
		// NOTE(review): threshold 5 looks very low for viral-load counts; confirm units.
		for (int i = 0; i < 5; ++i) {
			tempDouble = Double.parseDouble(line.get(22 + i));
			if (tempDouble <= 5.00) newLine.add("1");
			else newLine.add("2");
		}
		// Baseline histological grading (column 27), copied verbatim.
		newLine.add(line.get(27));
		// Baseline histological staging (column 28) is the class label (4 classes); present only
		// in labelled records.
		if (line.size() > newLine.size()) newLine.add(line.get(line.size() - 1));
		return newLine;
	}

	/**
	 * Trains on {@code patientData.txt}, then reports the model on the training data itself and
	 * on a 10% hold-out split.
	 */
	public static void main(String[] args) {
		PatientClassification patientClassification = new PatientClassification();
		try {
			if (patientClassification.readTrainingSet("patientData.txt") != 0) {
				// readTrainingSet has already printed the path it looked for.
				System.err.println("Training data file not found.");
				return;
			}
			patientClassification.reportModelSelf();
			patientClassification.reportModel(0.10);
		} catch (IOException e) {
			e.printStackTrace();
		}
	}
}

//参考数据:[测试数据](https://archive.ics.uci.edu/ml/datasets/Hepatitis+C+Virus+%28HCV%29+for+Egyptian+patients)
发布了8 篇原创文章 · 获赞 0 · 访问量 634

猜你喜欢

转载自blog.csdn.net/qq_34262612/article/details/103677716