机器学习实战决策树的Java实现

package com.haolidong.Decisiontree;

import java.util.Comparator;
import java.util.HashMap;
import java.util.Map.Entry;
/**
 * @author haolidong
 * @Description:  [Custom descending (largest first) ordering for HashMap&lt;String,Integer&gt; counts.]
 *
 * <p>Sorting the counts held in a map really means sorting its {@link Entry} objects; use
 * {@link #BY_VALUE_DESC} for that. The {@link #compare(HashMap, HashMap)} method casts each
 * argument to {@link Entry}. An ordinary {@code HashMap} is not an {@code Entry}, so for plain
 * maps that method throws {@link ClassCastException}.
 */
public class ComparatorImpl implements Comparator<HashMap<String,Integer>>{

	/**
	 * Orders map entries by value, largest first. Uses {@code compareTo} instead of subtraction,
	 * because subtraction overflows for extreme values (e.g. {@code Integer.MIN_VALUE}) and
	 * silently reverses the order.
	 */
	public static final Comparator<Entry<String, Integer>> BY_VALUE_DESC =
			new Comparator<Entry<String, Integer>>() {
				@Override
				public int compare(Entry<String, Integer> e1, Entry<String, Integer> e2) {
					// Arguments swapped relative to natural order => descending.
					return e2.getValue().compareTo(e1.getValue());
				}
			};

	/**
	 * Compares two arguments by their {@link Entry#getValue()}, largest first.
	 *
	 * @param o1 first argument; must also implement {@link Entry}
	 * @param o2 second argument; must also implement {@link Entry}
	 * @return a negative number if o1's value is larger, positive if smaller, 0 if equal
	 * @throws ClassCastException if either argument does not implement {@link Entry}
	 *         (which is the case for every plain {@code HashMap})
	 */
	@SuppressWarnings("unchecked")
	@Override
	public int compare(HashMap<String, Integer> o1, HashMap<String, Integer> o2) {
		// Only HashMap subclasses that also implement Map.Entry survive these casts.
		Entry<String, Integer> obj1 = (Entry<String, Integer>) o1;
		Entry<String, Integer> obj2 = (Entry<String, Integer>) o2;
		return BY_VALUE_DESC.compare(obj1, obj2);
	}
}

package com.haolidong.Decisiontree;

import java.util.ArrayList;
/**
 * @author haolidong
 * @Description:  [Holder for a feature matrix: a list of rows, each row a list of cell values.]
 * @parameter data:  [the rows of the matrix; starts out empty]
 */
public class Matrix {
	/** Rows of the matrix; allocated empty by this initializer for every new instance. */
	public ArrayList<ArrayList<String>> data = new ArrayList<ArrayList<String>>();

	/** Creates a matrix with no rows. */
	public Matrix() {
		// Nothing else to set up: the row list is created by the field initializer.
	}
}

package com.haolidong.Decisiontree;

import java.util.ArrayList;
/**
 * @author haolidong
 * @Description:  [A feature matrix together with the names of its feature columns.]
 * @parameter labels:  [feature names, one per feature column; the class column has no name]
 */
public class CreateDataSet extends Matrix {
	/** Feature names, in column order. */
	public ArrayList<String> labels;

	/** Creates an empty data set with no rows and no feature names. */
	public CreateDataSet() {
		super();
		this.labels = new ArrayList<String>();
	}

	/**
	 * @author haolidong
	 * @Description:  [Appends the first example data set of the decision-tree chapter:
	 *                two binary features ("no surfacing", "flippers") and a yes/no class
	 *                in the last column.]
	 */
	public void initTest() {
		String[][] table = {
			{"1", "1", "yes"},
			{"1", "1", "yes"},
			{"1", "0", "no"},
			{"0", "1", "no"},
			{"0", "1", "no"},
		};
		for (String[] cells : table) {
			ArrayList<String> row = new ArrayList<String>();
			for (String cell : cells) {
				row.add(cell);
			}
			data.add(row);
		}
		labels.add("no surfacing");
		labels.add("flippers");
	}
}

package com.haolidong.Decisiontree;

import java.util.ArrayList;
/**
 * @author haolidong
 * @Description:  [One node of the decision tree, standing in for the nested dicts of the
 *                Python version.]
 * @parameter  arrow:  [the parent's feature value that leads to this node; "" for the root]
 * @parameter  name:   [feature name for an internal node, class label for a leaf]
 * @parameter  arrDic: [child nodes; empty for a leaf]
 */
public class Dictionary {
	/** Branch value on the edge from the parent to this node ("" for the root). */
	public String arrow;
	/** Feature tested at this node, or the class label if this node is a leaf. */
	public String name;
	/** Children of this node; an empty list marks a leaf. */
	public ArrayList<Dictionary> arrDic;

	/**
	 * @author haolidong
	 * @Description:  [Creates a node with empty arrow and name and no children. A root node
	 *                simply leaves arrow empty.]
	 */
	public Dictionary() {
		// String literals are immutable and shared; new String("") only allocated redundant copies.
		arrow = "";
		name = "";
		arrDic = new ArrayList<Dictionary>();
	}
}

package com.haolidong.Decisiontree;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map.Entry;


	public class Decisiontree {
		/**
		 * @param args
		 * @author haolidong
		 * @Description:  [主函数主要对于各个实例进行测试]  
		 */
	public static void main(String[] args) {
		testCreateTree();
		testGlass();
	}
	
	/**
	 * @param inputTree 决策树
	 * @param testVec	测试向量【输入各个特征值进行测试】
	 * @return			返回最后的标签值
	 * @author 	        haolidong
	 * @Description:    [主函数主要对于各个实例进行测试]
	 */
	public static String classify(Dictionary inputTree,ArrayList<String> testVec){
		String result = new String();
		if(testVec.size()==0){
			result=inputTree.name;
		}else{
			for (int i = 0; i < inputTree.arrDic.size(); i++) {
				/*未来防止迭代没有结束,然后已经有返回值,这个时候后面的就不用继续进行了,testVec=0表示的是已经到达了叶子节点*/
				if(testVec.size()!=0){
					if(testVec.get(0).equals(inputTree.arrDic.get(i).arrow)){
						testVec.remove(testVec.get(0));
						result=classify(inputTree.arrDic.get(i),testVec);
					}
				}
				
			}
		}
		return result;
	}
	

	/**
	 * @param dataSet   数据集
	 * @param labels    分类的标签值
	 * @return          返回最终的决策树
	 * @author 	        haolidong
	 * @Description:    [生成决策树,当遇到标签值全部使用完,但是还是不能够把类完全分开,返回出现最多的标签值;
	 *                  当到达子节点的时候,也要跳出函数,这个分别是前两个if判断,每一次都选择信息增益最大的,
	 *                  然后递归进行划分,每一次递归都要去掉一个标签,一遍递归的终结  。 ]
	 */
	public static Dictionary createTree(Matrix dataSet,ArrayList<String> labels){
		ArrayList<String> classList = new ArrayList<String>();
		HashSet<String> setList = new HashSet<String>();
		String temps=new String("");
		for (int i = 0; i < dataSet.data.size(); i++) {
			temps = dataSet.data.get(i).get(dataSet.data.get(i).size()-1);
			classList.add(temps);
			setList.add(temps);
		}
		if(setList.size()==1){
			Dictionary dtemp = new Dictionary();
			dtemp.name = classList.get(0);
			return dtemp;
		}
		if(dataSet.data.get(0).size()==1){
			Dictionary stemp = new Dictionary();
			stemp.arrow = classList.get(0);
			return stemp;
		}
		int bestFeat = chooseBestFeatureToSplit(dataSet);
		String bestFeatLabel = labels.get(bestFeat);
		Dictionary myTree = new Dictionary();
		myTree.name=bestFeatLabel;
		labels.remove(bestFeat);
		ArrayList<String> featValues = new ArrayList<String>();
		HashSet<String> uniqueVals = new HashSet<String>();
		for (int i = 0; i < dataSet.data.size(); i++) {
			featValues.add(dataSet.data.get(i).get(bestFeat));
			uniqueVals.add(dataSet.data.get(i).get(bestFeat));
		}
		
		for (String value : uniqueVals) {
			ArrayList<String> subLabels = new ArrayList<String>();
			for (int j = 0; j < labels.size(); j++) {
				subLabels.add(labels.get(j));
			}
			Dictionary tempTree = new Dictionary();
			tempTree = createTree(splitDataSet(dataSet, bestFeat, value),subLabels);
			tempTree.arrow = value;
			myTree.arrDic.add(tempTree);
		}
		return myTree;
		
	}
	
	/**
	 * @param d
	 * @author 	        haolidong
	 * @Description:    [对于非叶子节输出他们自己的信息,然后判断字节点,子节点则直接输出]
	 */                  
	public static void displayDic(Dictionary d){
		if(d.arrDic.size()!=0){
			System.out.print("{"+d.name);
			if(d.arrDic.size()==0){
				System.out.print("}");
			}else{
				System.out.print(":");
				for (int i = 0; i < d.arrDic.size(); i++) {
					if(i==0)System.out.print("{");
					System.out.print(d.arrDic.get(i).arrow+":");
					displayDic(d.arrDic.get(i));
					if(i!=d.arrDic.size()-1){
						System.out.print(",");
					}
				}
				System.out.print("}");
				System.out.print("}");
			}
		}else {
			System.out.print(d.name);
			
		}
	}
	/**
	 * @param classList
	 * @return 返回当前出现次数最多的标签值
	 * @author 	        haolidong
	 * @Description:    [当且仅当标签全部用完时还没有把类别完全分离才使用的]
	 */
	public static Dictionary majorityCnt(ArrayList<String> classList){
		HashMap<String,Integer> classCount = new HashMap<String,Integer>();
		String vote;
		for (int i = 0; i < classList.size(); i++) {
			vote = classList.get(i);
			if(classCount.containsKey(vote)==true){
				classCount.put(vote, classCount.get(vote)+1);
			}else{
				classCount.put(vote, 1);
			}
		}
		ArrayList<HashMap.Entry<String,Integer>> entries= sortMap(classCount);
		Dictionary dtemp = new Dictionary();
		dtemp.name = entries.get(0).getKey();;
		return dtemp;
		
	}
	
	/**
	 * @param map       输入值是hashmap
	 * @return          返回排好序的map
	 * @author 	        haolidong
	 * @Description:    [对map的排序,这里是从大到小]
	 */
	public static ArrayList<HashMap.Entry<String,Integer>> sortMap(HashMap<String,Integer> map){  
	     List<HashMap.Entry<String, Integer>> entries = new ArrayList<HashMap.Entry<String, Integer>>(map.entrySet());  
	     Collections.sort(entries, new Comparator<HashMap.Entry<String, Integer>>() {  
	    	 public int compare(HashMap.Entry<String, Integer> obj1 , HashMap.Entry<String, Integer> obj2) {  
	             return obj2.getValue() - obj1.getValue();  
	         }  
	     });  
	      return (ArrayList<Entry<String, Integer>>) entries;  
	    }  
  
	/**
	 * @param DataSet   特征矩阵
	 * @return          返回需要切分的特征向量的下标
	 * @author 	        haolidong
	 * @Description:    [根据信息增益,选择最好的切分]
	 */
	public static int chooseBestFeatureToSplit(Matrix DataSet){
		int numFeatures = DataSet.data.get(0).size()-1;
		double baseEntropy = calcShannonEnt(DataSet);
		double bestInfoGain = 0.0;
		int bestFeature=-1;
		HashSet<String> uniqueVals = new HashSet<String>();
		for (int i = 0; i < numFeatures; i++) {
			uniqueVals.clear();
			for (int j = 0; j < DataSet.data.size(); j++) {
				uniqueVals.add(DataSet.data.get(j).get(i));
			}
			double newEntropy = 0.0;
			double prob = 0.0;
			for(String value:uniqueVals){
				Matrix subDataSet = new Matrix();
				subDataSet = splitDataSet(DataSet, i, value);
				prob = 1.0*subDataSet.data.size()/DataSet.data.size();
				newEntropy = newEntropy + prob * calcShannonEnt(subDataSet);
			}
			double infoGain = baseEntropy - newEntropy;
			if(infoGain > bestInfoGain){
				bestInfoGain = infoGain;
				bestFeature = i;
			}
		}
		return bestFeature;
	
	}
	/**
	 * @param DataSet  数据集
	 * @author haolidong
	 * @Description:  [求香农熵:H=[求和]-p(x)log2 p(x)] 
	 * @return 最后的香农熵
	 */
	public static double calcShannonEnt(Matrix DataSet){
		int numEntries = DataSet.data.size();
		HashMap<String,Integer> classCount = new HashMap<String,Integer>();
		String currentLabel;
		for (int i = 0; i < numEntries; i++) {
			currentLabel = DataSet.data.get(i).get(DataSet.data.get(i).size()-1);
			if(classCount.containsKey(currentLabel)==true){
				classCount.put(currentLabel, classCount.get(currentLabel)+1);
			}else{
				classCount.put(currentLabel, 1);
			}
		}
		double shannonEnt = 0.0;
		double prob = 0.0;
		for(HashMap.Entry<String,Integer> entry:classCount.entrySet()){
			prob = 1.0*entry.getValue()/numEntries;
			shannonEnt =shannonEnt -prob *Math.log(prob)/Math.log(2);
		}
		return shannonEnt;
	}
	/**
	 * @param dataSet   输入数据集
	 * @param axis      输入删除的列下标
	 * @param value     把低axis列下标为value的值删除以后,把这一行放入ArrayList
	 * @return          返回符合第axis列的特征向量为value的矩阵【删除了axis列】
	 * @author 	        haolidong
	 * @Description:    [返回符合第axis列的特征向量为value的矩阵【删除了axis列]
	 */
	public static Matrix splitDataSet(Matrix dataSet, int axis, String value){
		Matrix retDataSet = new Matrix();
		for (int i = 0; i < dataSet.data.size(); i++) {
			if(dataSet.data.get(i).get(axis).equals(value)){
				ArrayList<String> as = new ArrayList<String>();
				for (int j = 0; j < dataSet.data.get(i).size(); j++) {
					if(j!=axis){
						as.add(dataSet.data.get(i).get(j));
					}
				}
				retDataSet.data.add(as);
			}
		}
		return retDataSet;
	}
	/**
	 * @return          返回数据集
	 * @author 	        haolidong
	 * @Description:    [对香农熵的测试]
	 */
	public static CreateDataSet testShannon(){
		CreateDataSet DataSet = new CreateDataSet();
		DataSet.initTest();
		System.out.println(calcShannonEnt(DataSet));
		return DataSet;
	}
	/**
	 * @author 	        haolidong
	 * @Description:    [对分割数据集的测试]
	 */
	public static void testSplitDataSet() {
		CreateDataSet DataSet = new CreateDataSet();
		Matrix m =new Matrix();
		DataSet.initTest();
		m=splitDataSet(DataSet,0,"1");
		System.out.println(m);
	}
	/**
	 * @author 	        haolidong
	 * @Description:    [对最佳分割数据集的测试]
	 */
	public static void testChooseBestFeatureToSplit() {
		CreateDataSet DataSet = new CreateDataSet();
		DataSet.initTest();
		System.out.println(chooseBestFeatureToSplit(DataSet));
	}
	/**
	 * @author 	        haolidong
	 * @Description:    [对于当标签全部用完时还没有把类别完全分离的函数进行测试]
	 */
	public static void testmajortityCnt() {
		CreateDataSet DataSet = new CreateDataSet();
		DataSet.initTest();
		ArrayList<String> as = new ArrayList<String>();
		for (int i = 0; i < DataSet.data.size(); i++) {
			as.add(new String(DataSet.data.get(i).get(DataSet.data.get(i).size()-1)));
		}	
		majorityCnt(as);
	}
	/**
	 * @author 	        haolidong
	 * @Description:    [对决策树显示结果的测试]
	 */
	public static void testDisplayDir() {
		Dictionary d1 = new Dictionary();
		Dictionary d2 = new Dictionary();
		Dictionary d3 = new Dictionary();
		Dictionary d4 = new Dictionary();
		Dictionary d5 = new Dictionary();
//		Dictionary d6 = new Dictionary();
//		d6.name="hld";
//		d6.arrow="2";
		d1.arrow="0";
		d1.name="no";
		d2.arrow="1";
		d2.name="yes";		
		d3.arrow="1";
		d3.name="flippers";
		d3.arrDic.add(d1);
		d3.arrDic.add(d2);
//		d4.arrDic.add(d6);
		d4.name="no";
		d4.arrow="0";
		//root
		d5.name="no surfacing";
		d5.arrDic.add(d4);
		d5.arrDic.add(d3);
		displayDic(d5);
	}
	/**
	 * @author 	        haolidong
	 * @Description:    [验证决策树的分类效果]
	 */
	public static void testClassify() {
		CreateDataSet DataSet = new CreateDataSet();
		ArrayList<String> testVec = new ArrayList<String>();
		DataSet.initTest();
		Dictionary myTree = new Dictionary();
		myTree=createTree(DataSet,DataSet.labels);
		testVec.add("1");
		testVec.add("0");
//		displayDic(myTree);
		System.out.println(classify(myTree,testVec));
	}
	/**
	 * @author 	        haolidong
	 * @Description:    [对书上最后一个例子的测试【对于隐形眼镜的测试】]
	 */
	public static void testGlass(){
		String fileName = "I:\\machinelearninginaction\\Ch03\\lenses.txt";
		File file = new File(fileName);
		CreateDataSet DataSet = new CreateDataSet();
        BufferedReader reader = null;
        try {
            reader = new BufferedReader(new FileReader(file));
            String tempString = null;
            // 一次读入一行,直到读入null为文件结束
            while ((tempString = reader.readLine()) != null) {
                // 显示行号
                String[] strArr = tempString.split("\t");
                ArrayList<String> as = new ArrayList<String>();
                for (int i = 0; i < strArr.length; i++) {
					as.add(strArr[i]);
				}
                DataSet.data.add(as);
            }
            reader.close();
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            if (reader != null) {
                try {
                    reader.close();
                } catch (IOException e1) {
                }
            }
        }
        DataSet.labels.add(new String("age"));
        DataSet.labels.add(new String("prescript"));
        DataSet.labels.add(new String("astigmatic"));
        DataSet.labels.add(new String("tearRate"));
        Dictionary myTree = new Dictionary();
        myTree=createTree(DataSet,DataSet.labels);
		displayDic(myTree);
        
	}
	/**
	 * @author 	        haolidong
	 * @Description:    [对建树的测试]
	 */
	public static void testCreateTree() {
		CreateDataSet DataSet = new CreateDataSet();
		DataSet.initTest();
		Dictionary myTree = new Dictionary();
		myTree=createTree(DataSet,DataSet.labels);
		displayDic(myTree);
		
	}
}

猜你喜欢

转载自blog.csdn.net/qq_22125259/article/details/49309877