机器学习算法——决策树ID3算法介绍以及Java实现

一、 决策树算法

决策树:是一种用于对实例进行分类的树形结构,可以是二叉树或非二叉树,由节点(node)和有向边(directed edge)组成。其中每个非叶子节点表示一个特征属性,叶子节点代表类别属性,它的值由根节点到叶子节点这一分支的属性值确定。使用决策树进行分类的过程,就是从根节点出发,训练数据的分支走向,直到得到叶子节点的值停止计算,这时即可输出类别。

决策树算法是从数据的属性(或者特征)出发,以属性作为基础,划分不同的类。实现决策树的算法有很多种,有ID3、C4.5和CART等算法。下面我们介绍ID3算法。

二、ID3算法

ID3算法是由Quinlan首先提出的,该算法是以信息论为基础,以信息熵和信息增益为衡量标准,从而实现对数据的归纳分类。

算法原理:设
这里写图片描述
为训练样本,对于每个样本有m个属性,用A,B,C,……来表示每一个属性。样本类别为
这里写图片描述
算法计算步骤如下:
首先,构造分类决策树。
(1)计算训练样本D的信息熵,即
这里写图片描述
Pi 表示第i个类别个数占训练样本总数的比例。
(2)分别计算每个属性的条件熵,例如计算属性A相对于D的期望信息,即:
这里写图片描述
其中,
这里写图片描述
表示属性A将D划分为v个子集。
这里写图片描述
表示第j个子集的样本数比上样本总数。
(3)计算属性A的信息增益,即
这里写图片描述
同理,重复第二步、第三步,计算出其他属性的信息增益,直至所有属性计算完成。选择信息增益最大的属性为根节点,对样本集D进行第一次分裂;然后对余下的属性重复上述的步骤,选择作为第二层节点的属性,直到找到叶子节点为止。此时,就构造出一颗决策分类树。

决策树构造完成后,就可以对待分类的样本进行分类,得到类别。

三、ID3算法实例讲解

图中数据为训练样本,现在预测E= {天气=晴,温度=适中,湿度=正常,风速=弱} 的情况下活动是取消还是进行。属性为天气、温度、湿度、风速,类别为取消和进行。
这里写图片描述

样本D的信息熵,
这里写图片描述
在天气为晴时有5种情况,发现活动取消有3种,进行有2种,计算现在的条件熵:
这里写图片描述
天气为阴时有4种情况,活动进行的有4种,则条件熵为:
这里写图片描述
天气为雨时有5种情况,活动取消的有2种,进行的有3种,则条件熵为:
这里写图片描述
由于按照天气属性不同取值划分时,天气为晴占整个情况的5/14,天气为阴占整个情况的4/14,天气为雨占整个情况的5/14,则按照天气属性不同取值划分时的带权平均值熵为:
这里写图片描述
算出的结果约为0.693.

则此时的信息增益Gain(活动,天气)= H(活动) - H(活动|天气) = 0.94- 0.693 = 0.246

同理我们可以计算出按照温度属性不同取值划分后的信息增益:

Gain(活动,温度)= H(活动) - H(活动|温度) = 0.94- 0.911 = 0.029

按照湿度属性不同取值划分后的信息增益:

Gain(活动,湿度)= H(活动) - H(活动|湿度) = 0.94- 0.789 = 0.151

按照风速属性不同取值划分后的信息增益:

Gain(活动,风速)= H(活动) - H(活动|风速) = 0.94- 0.892 = 0.048

决策树的构造就是要选择当前信息增益最大的属性来作为当前决策树的节点。因此我们选择天气属性来做为决策树根节点,这时天气属性有3取值可能:晴,阴,雨,我们发现当天气为阴时,活动全为进行因此这件事情就可以确定了,而天气为晴或雨时,活动中有进行的也有取消的,事件还无法确定,这时就需要在剩下的属性中递归再次计算活动熵和信息增益,选择信息增益最大的属性来作为下一个节点,直到整个事件能够确定下来。

最后得到的决策树如下图所示:
这里写图片描述
所以,E= {天气=晴,温度=适中,湿度=正常,风速=弱} 的情况下活动是进行。

四、ID3算法Java实现

下面是实例的Java代码实现,算法实现前,需要转换为数字向量,其中,天气有晴,阴,雨,可分配2,1,0三个值,即晴=2,阴=1,雨=0;同理,风速,强= 1,弱=0;湿度,高=1,正常=0;温度,炎热=2,正常=1,寒冷=0。

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;

/**
 *
 * @author X.H.Yang
 */
public class TreeID3Act {

    public static String[] createDataLable() {
        String lable[] = {"weather", " temper", "humidity", "wind"};
        return lable;
    }

    public static Object[][] createDataSet() {
        Object set[][] = {{2, 2, 1, 0,"no"},
        {2, 2, 1, 1, "no"},
        {1, 2, 1, 0, "yes"},
        {0, 1, 1, 0,"yes"},
        {0, 0, 0, 0, "yes"},
        {0, 0, 0, 1, "no"},
        {1, 0, 0, 1, "yes"},
        {2, 1, 1, 0, "no"},
        {2, 0, 0, 0, "yes"},
        {0, 1, 0, 0, "yes"},
        {2, 1, 0, 1, "yes"},
        {1, 1, 1, 1, "yes"},
        {1, 2, 0, 0, "yes"},
        {0, 1, 1, 1, "no"}};
        return set;
    }

    public static double calcShannonEnt(Object[][] dataSet) {
        double shannonEnt = 0.0;
        int numEntries = dataSet.length;
        HashMap<String, Integer> labelCounts = new HashMap<>();
        for (Object[] featVec : dataSet) {
            String currentLabel = (String) featVec[featVec.length - 1];
            if (labelCounts.get(currentLabel) == null) {
                labelCounts.put(currentLabel, 1);
            } else {
                int i = labelCounts.get(currentLabel);
                i++;
                labelCounts.put(currentLabel, i);
            }
        }
        for (Entry<String, Integer> entry : labelCounts.entrySet()) {
            double prob = (double) entry.getValue() / (double) numEntries;
            shannonEnt -= prob * (Math.log(prob) / Math.log(2.0d));
        }
        return shannonEnt;
    }

    public static ArrayList<Object[]> splitDataSet(Object dataSet[][], int axis, Object value) {
        ArrayList<Object[]> retDataSet = new ArrayList<>();
        for (Object[] featVec : dataSet) {
            Object subSet[] = null;
            if (featVec[axis].equals(value)) {
                subSet = new Object[dataSet[0].length - 1];
                if (axis == 0) {
                    System.arraycopy(featVec, 1, subSet, 0, subSet.length);
                } else {
                    System.arraycopy(featVec, 0, subSet, 0, axis);
                    System.arraycopy(featVec, axis + 1, subSet, axis, subSet.length - axis);
                }
                retDataSet.add(subSet);
            }
        }
        return retDataSet;
    }

    public static int chooseBestFeatureToSplit(Object dataSet[][]) {
        int numFeatures = dataSet[0].length - 1;
        double baseEntropy = calcShannonEnt(dataSet);
        double bestInfoGain = 0.0;
        int bestFeature = -1;
        for (int f = 0; f < numFeatures; f++) {
            HashSet<Object> uniqueVals = new HashSet<>();
            for (Object set[] : dataSet) {
                uniqueVals.add(set[f]);
            }
            double newEntropy = 0.0;
            for (Object obj : uniqueVals) {
                ArrayList<Object[]> subDataSet = splitDataSet(dataSet, f, obj);
                double prob = (double) subDataSet.size() / (double) dataSet.length;
                Object subSetArray[][] = new Object[subDataSet.size()][numFeatures];
                for (int i = 0; i < subSetArray.length; i++) {
                    System.arraycopy(subDataSet.get(i), 0, subSetArray[i], 0, numFeatures);
                }
                newEntropy += prob * calcShannonEnt(subSetArray);
            }
            double infoGain = baseEntropy - newEntropy;
            if (infoGain > bestInfoGain) {
                bestInfoGain = infoGain;
                bestFeature = f;
            }
        }
        return bestFeature;
    }

    public static String majorityCnt(String classList[]) {
        HashMap<String, Integer> classCount = new HashMap<>();
        for (String vote : classList) {
            if (!classCount.containsKey(vote)) {
                classCount.put(vote, 1);
            } else {
                int i = classCount.get(vote);
                i++;
                classCount.put(vote, i);
            }
        }
        LinkedHashMap<String, Integer> sortMap = sortMapByValues(classCount);
        return sortMap.entrySet().iterator().next().getKey();
    }

    private static Object createTree(Object dataSet[][], String labels[]) {
        //classList = [example[-1] for example in dataSet]
        ArrayList<String> classList = new ArrayList();
        for (Object set[] : dataSet) {
            classList.add((String) set[set.length - 1]);
        }
        if (ListCount(classList, classList.get(0)) == classList.size()) {
            return classList.get(0);
        }
        if (dataSet[0].length == 1) {
            return majorityCnt((String[]) classList.toArray());
        }
        int bestFeat = chooseBestFeatureToSplit(dataSet);
        String bestFeatLabel = labels[bestFeat];
        HashMap myTree = new HashMap<>();
        myTree.put(bestFeatLabel, new HashMap<>());
        String sublabels[] = new String[labels.length - 1];
        if (bestFeat == 0) {
            System.arraycopy(labels, 1, sublabels, 0, sublabels.length);
        } else {
            System.arraycopy(labels, 0, sublabels, 0, bestFeat);
            System.arraycopy(labels, bestFeat + 1, sublabels, bestFeat, sublabels.length - bestFeat);
        }
        HashSet<Object> uniqueVals = new HashSet<>();
        for (Object set[] : dataSet) {
            uniqueVals.add(set[bestFeat]);
        }
        for(Object value : uniqueVals){
            ArrayList<Object[]> setlist = splitDataSet(dataSet, bestFeat, value);
            int j = 0;
            Object dataSetM[][] =new Object[setlist.size()][];
            for(Object set[]:setlist){
                dataSetM[j] = set;
                j++;
            }
            Object tree = createTree(dataSetM,sublabels);
            HashMap subtree = (HashMap) myTree.get(bestFeatLabel);
            subtree.put(value, tree);
        }
        return myTree;
    }

    private static String classify(HashMap<Object,Object> inputTree,String featLabels[],Object[] testVec){
        String firstStr = (String) inputTree.keySet().iterator().next();
        HashMap secondDict = (HashMap) inputTree.get(firstStr);
        int featIndex = 0;
        for(featIndex=0;featIndex<featLabels.length;featIndex++){
            if(featLabels[featIndex].equals(firstStr)){
                break;
            }
        }
        Object key = testVec[featIndex];
        Object valueOfFeat = secondDict.get(key);
        String classLabel ="erro";
        if(valueOfFeat instanceof HashMap){
            classLabel = classify((HashMap<Object, Object>) valueOfFeat, featLabels, testVec);
        }else{
            classLabel = (String)valueOfFeat;
        }
        return classLabel;
    }

    private static int ListCount(ArrayList<String> classList, String key) {
        int c = 0;
        for (String clazz : classList) {
            if (clazz.equals(key)) {
                c++;
            }
        }
        return c;
    }

    private static LinkedHashMap<String, Integer> sortMapByValues(Map<String, Integer> aMap) {

        Set<Entry<String, Integer>> mapEntries = aMap.entrySet();

        //System.out.println("Values and Keys before sorting ");
        //for (Entry<String, Integer> entry : mapEntries) {
        //System.out.println(entry.getValue() + " - " + entry.getKey());
        //}
        // used linked list to sort, because insertion of elements in linked list is faster than an array list. 
        List<Entry<String, Integer>> aList = new LinkedList<Entry<String, Integer>>(mapEntries);

        // sorting the List 
        Collections.sort(aList, new Comparator<Entry<String, Integer>>() {

            @Override
            public int compare(Entry<String, Integer> ele1, Entry<String, Integer> ele2) {

                return -(ele1.getValue().compareTo(ele2.getValue()));
            }
        });

        // Storing the list into Linked HashMap to preserve the order of insertion. 
        LinkedHashMap<String, Integer> aMap2 = new LinkedHashMap<String, Integer>();
        for (Entry<String, Integer> entry : aList) {
            aMap2.put(entry.getKey(), entry.getValue());
        }
        return aMap2;
    }

    public static void main(String[] args) throws Exception {
        Object Dataset[][] = createDataSet();
        Object tree = createTree(Dataset,new String[]{"weather", " temper", "humidity", "wind"});
        System.out.println(tree);
        String clazz = classify((HashMap<Object, Object>) tree,new String[]{"weather", " temper", "humidity", "wind"},new Object[]{2,1,0,0});
        System.out.println(clazz);
    }
}

程序输出结果为:{weather={0={wind={0=yes, 1=no}}, 1=yes, 2={humidity={0=yes, 1=no}}}}
测试样本结果:yes

猜你喜欢

转载自blog.csdn.net/xiaoxiao_yang77/article/details/79262704