package com.haolidong.Decisiontree;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Map.Entry;
/**
* @author haolidong
* @Description: [Comparator intended for custom sorting of HashMap data by value, in descending order (largest first)]
* NOTE(review): the declared type argument is HashMap<String,Integer>, but compare() casts each
* argument to Map.Entry<String,Integer>. HashMap does not implement Map.Entry, so for a plain
* HashMap argument the cast throws ClassCastException. Decisiontree.sortMap does not use this
* class; it sorts with its own anonymous Comparator over Map.Entry. A real fix needs the type
* argument changed to Map.Entry<String,Integer>, which changes the public signature.
* TODO confirm no caller depends on Comparator<HashMap<String,Integer>> before changing it.
*/
public class ComparatorImpl implements Comparator<HashMap<String,Integer>>{
@SuppressWarnings("unchecked")
@Override
public int compare(HashMap<String, Integer> o1, HashMap<String, Integer> o2) {
// Unchecked casts to Map.Entry; see the NOTE in the class comment — these fail for plain HashMaps.
Entry<String, Integer> obj1 = (Entry<String, Integer>) o1;
Entry<String, Integer> obj2 = (Entry<String, Integer>) o2;
// Descending by value (o2 before o1). NOTE(review): int subtraction can overflow when the values
// have opposite signs; obj2.getValue().compareTo(obj1.getValue()) would be overflow-safe.
return ((Integer) (obj2.getValue()) - (Integer) (obj1.getValue()));
}
}
package com.haolidong.Decisiontree;
import java.util.ArrayList;
/**
 * Container for a feature matrix.
 *
 * @author haolidong
 */
public class Matrix {
    /** The matrix rows; each row is a list of string-valued cells. Starts out empty. */
    public ArrayList<ArrayList<String>> data;

    /** Creates a matrix with no rows. */
    public Matrix() {
        this.data = new ArrayList<ArrayList<String>>();
    }
}
package com.haolidong.Decisiontree;
import java.util.ArrayList;
/**
 * A feature matrix together with the names of its feature columns.
 *
 * @author haolidong
 */
public class CreateDataSet extends Matrix {
    /** Names of the feature columns, in column order (the trailing class column has no name). */
    public ArrayList<String> labels;

    /** Creates an empty data set with no rows and no labels. */
    public CreateDataSet() {
        super();
        this.labels = new ArrayList<String>();
    }

    /**
     * Appends the data of the first decision-tree example in "Machine Learning in Action":
     * five rows of (no surfacing, flippers, class) and the two feature names.
     */
    public void initTest() {
        data.add(row("1", "1", "yes"));
        data.add(row("1", "1", "yes"));
        data.add(row("1", "0", "no"));
        data.add(row("0", "1", "no"));
        data.add(row("0", "1", "no"));
        labels.add("no surfacing");
        labels.add("flippers");
    }

    /** Builds a fresh mutable row containing the given cells in order. */
    private static ArrayList<String> row(String... cells) {
        ArrayList<String> r = new ArrayList<String>(cells.length);
        for (String cell : cells) {
            r.add(cell);
        }
        return r;
    }
}
package com.haolidong.Decisiontree;
import java.util.ArrayList;
/**
 * A tree node that mimics a nested Python dict; the generated decision tree is stored as a
 * tree of these nodes.
 *
 * <ul>
 *   <li>{@code arrow}: the branch value on the edge from the parent to this node (empty for the root)</li>
 *   <li>{@code name}: this node's name (a feature label for inner nodes, a class label for leaves)</li>
 *   <li>{@code arrDic}: the child nodes</li>
 * </ul>
 *
 * @author haolidong
 */
public class Dictionary {
    /** Branch value leading from the parent to this node; stays empty for the root. */
    public String arrow;
    /** Name of this node. */
    public String name;
    /** Child nodes; empty for a leaf. */
    public ArrayList<Dictionary> arrDic;

    /** Creates a node with empty {@code arrow} and {@code name} and no children. */
    public Dictionary() {
        // Plain literals: new String("") only allocated a redundant copy of the empty string.
        arrow = "";
        name = "";
        arrDic = new ArrayList<Dictionary>();
    }
}
package com.haolidong.Decisiontree;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map.Entry;
/**
 * ID3 decision tree (port of "Machine Learning in Action", chapter 3).
 *
 * <p>Data sets are {@link Matrix} instances whose last column holds the class label and whose
 * other columns hold feature values. Trees are built as {@link Dictionary} nodes.
 *
 * @author haolidong
 */
public class Decisiontree {

    /** Default location of the contact-lenses data set used by {@link #testGlass()}. */
    private static final String DEFAULT_LENSES_FILE = "I:\\machinelearninginaction\\Ch03\\lenses.txt";

    /**
     * Runs the sample demos.
     *
     * @param args optional; {@code args[0]} overrides the path of the lenses data file
     */
    public static void main(String[] args) {
        testCreateTree();
        if (args.length > 0) {
            testGlass(args[0]);
        } else {
            testGlass();
        }
    }

    /**
     * Classifies a test vector by walking the tree from the root.
     *
     * <p>The value at position i of {@code testVec} is matched against the branch values
     * ({@code arrow}) of the children of the node reached at depth i. The values must therefore
     * be given in the order in which the tree tests its features.
     *
     * @param inputTree root of the decision tree
     * @param testVec feature values to test; not modified
     * @return the name of the node where the walk stops: a leaf, or the node reached when
     *     {@code testVec} is exhausted; an empty string if no branch matches a value
     */
    public static String classify(Dictionary inputTree, ArrayList<String> testVec) {
        return classifyFrom(inputTree, testVec, 0);
    }

    /** Recursive step of {@link #classify}: matches {@code testVec.get(depth)} below {@code node}. */
    private static String classifyFrom(Dictionary node, List<String> testVec, int depth) {
        // Stop at a leaf even if values remain; continuing to scan sibling branches would
        // yield the wrong label.
        if (node.arrDic.isEmpty() || depth >= testVec.size()) {
            return node.name;
        }
        String value = testVec.get(depth);
        for (Dictionary child : node.arrDic) {
            if (value.equals(child.arrow)) {
                return classifyFrom(child, testVec, depth + 1);
            }
        }
        return "";
    }

    /**
     * Builds a decision tree recursively.
     *
     * <p>Recursion stops in two cases:
     * <ul>
     *   <li>all samples share one class, which gives a leaf with that class;</li>
     *   <li>no feature columns remain, which gives a leaf with the majority class.</li>
     * </ul>
     * Otherwise the data is split on the feature with the largest information gain. One
     * subtree is built per distinct value of that feature, and the feature is removed from
     * the label list passed down.
     *
     * @param dataSet data set; last column is the class label
     * @param labels names of the feature columns; not modified
     * @return root of the tree; an empty node for an empty data set
     */
    public static Dictionary createTree(Matrix dataSet, ArrayList<String> labels) {
        if (dataSet.data.isEmpty()) {
            return new Dictionary();
        }
        ArrayList<String> classList = new ArrayList<String>();
        HashSet<String> distinctClasses = new HashSet<String>();
        for (ArrayList<String> row : dataSet.data) {
            String label = row.get(row.size() - 1);
            classList.add(label);
            distinctClasses.add(label);
        }
        if (distinctClasses.size() == 1) {
            Dictionary leaf = new Dictionary();
            leaf.name = classList.get(0);
            return leaf;
        }
        // Only the class column is left: nothing more to split on, so vote.
        if (dataSet.data.get(0).size() == 1) {
            return majorityCnt(classList);
        }
        int bestFeat = chooseBestFeatureToSplit(dataSet);
        if (bestFeat < 0) {
            return majorityCnt(classList);
        }
        Dictionary myTree = new Dictionary();
        myTree.name = labels.get(bestFeat);
        // Work on a copy so the caller's label list is left intact.
        ArrayList<String> subLabels = new ArrayList<String>(labels);
        subLabels.remove(bestFeat);
        HashSet<String> uniqueVals = new HashSet<String>();
        for (ArrayList<String> row : dataSet.data) {
            uniqueVals.add(row.get(bestFeat));
        }
        for (String value : uniqueVals) {
            Dictionary subTree = createTree(splitDataSet(dataSet, bestFeat, value), subLabels);
            subTree.arrow = value;
            myTree.arrDic.add(subTree);
        }
        return myTree;
    }

    /**
     * Prints a tree in a Python-dict-like form, e.g. {@code {root:{0:no,1:{child:{...}}}}}.
     * An inner node prints as {@code {name:{arrow:subtree,...}}}, a leaf as just its name.
     * No trailing newline is printed.
     *
     * @param d root of the (sub)tree to print
     */
    public static void displayDic(Dictionary d) {
        if (d.arrDic.isEmpty()) {
            System.out.print(d.name);
            return;
        }
        System.out.print("{" + d.name + ":{");
        for (int i = 0; i < d.arrDic.size(); i++) {
            if (i > 0) {
                System.out.print(",");
            }
            Dictionary child = d.arrDic.get(i);
            System.out.print(child.arrow + ":");
            displayDic(child);
        }
        System.out.print("}}");
    }

    /**
     * Returns a leaf labelled with the most frequent class in {@code classList}.
     * It is used when all features are used up but the classes are still mixed.
     * Ties go to whichever tied class comes first in the HashMap's iteration order.
     *
     * @param classList class labels of the remaining samples
     * @return leaf node whose name is the majority class (empty name if the list is empty)
     */
    public static Dictionary majorityCnt(ArrayList<String> classList) {
        HashMap<String, Integer> classCount = new HashMap<String, Integer>();
        for (String vote : classList) {
            Integer count = classCount.get(vote);
            classCount.put(vote, count == null ? 1 : count + 1);
        }
        Dictionary leaf = new Dictionary();
        if (!classCount.isEmpty()) {
            leaf.name = sortMap(classCount).get(0).getKey();
        }
        return leaf;
    }

    /**
     * Returns the map's entries sorted by value, largest first (stable for equal values).
     *
     * @param map map to sort
     * @return new list of the map's entries in descending value order
     */
    public static ArrayList<HashMap.Entry<String, Integer>> sortMap(HashMap<String, Integer> map) {
        ArrayList<HashMap.Entry<String, Integer>> entries =
                new ArrayList<HashMap.Entry<String, Integer>>(map.entrySet());
        Collections.sort(entries, new Comparator<HashMap.Entry<String, Integer>>() {
            @Override
            public int compare(HashMap.Entry<String, Integer> a, HashMap.Entry<String, Integer> b) {
                // compareTo rather than subtraction: b - a overflows for values of opposite sign.
                return b.getValue().compareTo(a.getValue());
            }
        });
        return entries;
    }

    /**
     * Chooses the feature column whose split gives the largest information gain.
     * On equal gain the lower index wins.
     *
     * @param dataSet data set; last column is the class label
     * @return index of the best feature, or -1 if the data set is empty or has no feature columns
     */
    public static int chooseBestFeatureToSplit(Matrix dataSet) {
        if (dataSet.data.isEmpty()) {
            return -1;
        }
        int numFeatures = dataSet.data.get(0).size() - 1;
        double baseEntropy = calcShannonEnt(dataSet);
        double bestInfoGain = 0.0;
        int bestFeature = -1;
        HashSet<String> uniqueVals = new HashSet<String>();
        for (int i = 0; i < numFeatures; i++) {
            uniqueVals.clear();
            for (ArrayList<String> row : dataSet.data) {
                uniqueVals.add(row.get(i));
            }
            double newEntropy = 0.0;
            for (String value : uniqueVals) {
                Matrix subDataSet = splitDataSet(dataSet, i, value);
                double prob = 1.0 * subDataSet.data.size() / dataSet.data.size();
                newEntropy += prob * calcShannonEnt(subDataSet);
            }
            double infoGain = baseEntropy - newEntropy;
            // Always take the first feature as a candidate. Then a real index is returned even
            // when no feature has positive gain (e.g. XOR-like data), instead of -1.
            if (bestFeature == -1 || infoGain > bestInfoGain) {
                bestInfoGain = infoGain;
                bestFeature = i;
            }
        }
        return bestFeature;
    }

    /**
     * Computes the Shannon entropy of the class column: H = -sum p(x) * log2 p(x).
     *
     * @param dataSet data set; last column is the class label
     * @return the entropy in bits (0.0 for an empty data set)
     */
    public static double calcShannonEnt(Matrix dataSet) {
        int numEntries = dataSet.data.size();
        HashMap<String, Integer> classCount = new HashMap<String, Integer>();
        for (ArrayList<String> row : dataSet.data) {
            String currentLabel = row.get(row.size() - 1);
            Integer count = classCount.get(currentLabel);
            classCount.put(currentLabel, count == null ? 1 : count + 1);
        }
        double shannonEnt = 0.0;
        for (Integer count : classCount.values()) {
            double prob = 1.0 * count / numEntries;
            shannonEnt -= prob * Math.log(prob) / Math.log(2);
        }
        return shannonEnt;
    }

    /**
     * Selects the rows whose column {@code axis} equals {@code value}, with that column removed.
     *
     * @param dataSet input data set (not modified)
     * @param axis index of the column to test and drop
     * @param value value the column must equal
     * @return new matrix holding the matching rows without column {@code axis}
     */
    public static Matrix splitDataSet(Matrix dataSet, int axis, String value) {
        Matrix retDataSet = new Matrix();
        for (ArrayList<String> row : dataSet.data) {
            if (row.get(axis).equals(value)) {
                ArrayList<String> reduced = new ArrayList<String>(row.size() - 1);
                for (int j = 0; j < row.size(); j++) {
                    if (j != axis) {
                        reduced.add(row.get(j));
                    }
                }
                retDataSet.data.add(reduced);
            }
        }
        return retDataSet;
    }

    /**
     * Demo: prints the entropy of the sample data set.
     *
     * @return the sample data set
     */
    public static CreateDataSet testShannon() {
        CreateDataSet dataSet = new CreateDataSet();
        dataSet.initTest();
        System.out.println(calcShannonEnt(dataSet));
        return dataSet;
    }

    /** Demo: prints the rows of the sample data set whose first feature is "1". */
    public static void testSplitDataSet() {
        CreateDataSet dataSet = new CreateDataSet();
        dataSet.initTest();
        Matrix m = splitDataSet(dataSet, 0, "1");
        // Print the rows; Matrix has no toString, so printing m itself shows only a hash.
        System.out.println(m.data);
    }

    /** Demo: prints the index of the best feature to split the sample data set on. */
    public static void testChooseBestFeatureToSplit() {
        CreateDataSet dataSet = new CreateDataSet();
        dataSet.initTest();
        System.out.println(chooseBestFeatureToSplit(dataSet));
    }

    /** Demo: prints the majority class of the sample data set. */
    public static void testmajortityCnt() {
        CreateDataSet dataSet = new CreateDataSet();
        dataSet.initTest();
        ArrayList<String> classes = new ArrayList<String>();
        for (ArrayList<String> row : dataSet.data) {
            classes.add(row.get(row.size() - 1));
        }
        System.out.println(majorityCnt(classes).name);
    }

    /** Demo: prints a hand-built tree. */
    public static void testDisplayDir() {
        Dictionary d1 = new Dictionary();
        Dictionary d2 = new Dictionary();
        Dictionary d3 = new Dictionary();
        Dictionary d4 = new Dictionary();
        Dictionary d5 = new Dictionary();
        d1.arrow = "0";
        d1.name = "no";
        d2.arrow = "1";
        d2.name = "yes";
        d3.arrow = "1";
        d3.name = "flippers";
        d3.arrDic.add(d1);
        d3.arrDic.add(d2);
        d4.name = "no";
        d4.arrow = "0";
        // root
        d5.name = "no surfacing";
        d5.arrDic.add(d4);
        d5.arrDic.add(d3);
        displayDic(d5);
        System.out.println();
    }

    /** Demo: builds the sample tree and classifies the vector (1, 0). */
    public static void testClassify() {
        CreateDataSet dataSet = new CreateDataSet();
        dataSet.initTest();
        Dictionary myTree = createTree(dataSet, dataSet.labels);
        ArrayList<String> testVec = new ArrayList<String>();
        testVec.add("1");
        testVec.add("0");
        System.out.println(classify(myTree, testVec));
    }

    /** Demo: builds and prints the contact-lenses tree, reading the data from the default path. */
    public static void testGlass() {
        testGlass(DEFAULT_LENSES_FILE);
    }

    /**
     * Demo: builds and prints the contact-lenses tree (the last example in the book's chapter).
     * Each non-blank line of the file is one tab-separated sample; the last field is the class.
     *
     * @param fileName path of the tab-separated lenses data file
     */
    public static void testGlass(String fileName) {
        CreateDataSet dataSet = new CreateDataSet();
        BufferedReader reader = null;
        try {
            reader = new BufferedReader(new FileReader(new File(fileName)));
            String line;
            // Read one line at a time until readLine returns null (end of file).
            while ((line = reader.readLine()) != null) {
                if (line.trim().isEmpty()) {
                    continue; // a blank line would become a bogus one-column row
                }
                String[] cells = line.split("\t");
                ArrayList<String> row = new ArrayList<String>(cells.length);
                for (String cell : cells) {
                    row.add(cell);
                }
                dataSet.data.add(row);
            }
        } catch (IOException e) {
            e.printStackTrace();
            return; // no usable data, so there is no tree to build
        } finally {
            if (reader != null) {
                try {
                    reader.close();
                } catch (IOException ignored) {
                    // best effort: the data has already been read (or reading already failed)
                }
            }
        }
        dataSet.labels.add("age");
        dataSet.labels.add("prescript");
        dataSet.labels.add("astigmatic");
        dataSet.labels.add("tearRate");
        Dictionary myTree = createTree(dataSet, dataSet.labels);
        displayDic(myTree);
        System.out.println();
    }

    /** Demo: builds and prints the tree for the sample data set. */
    public static void testCreateTree() {
        CreateDataSet dataSet = new CreateDataSet();
        dataSet.initTest();
        Dictionary myTree = createTree(dataSet, dataSet.labels);
        displayDic(myTree);
        System.out.println();
    }
}