数据挖掘 -- C4.5决策树算法

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/l1832876815/article/details/89469015
1. 算法原理

C4.5算法: 首先根据训练集求出各属性的条件信息熵info[i]，然后求出类别的信息熵infod，infod - info[i]即为每个属性的信息增益gain[i]；接着计算每个属性的分裂信息度量h[i]，gain[i] / h[i]即为该属性的信息增益率。递归地选择信息增益率最高的属性，按照该属性的取值对数据集进行分裂，并判断分裂之后每个子数据集的类别是否“纯”（即只包含一个类别）：如果是，则生成一个叶节点，其值为该类别；如果不是，则继续递归进行分裂。最终训练出一棵决策树。测试过程即从根节点出发，根据测试样例各属性的取值遍历决策树，直到到达叶节点，叶节点的类别即为该测试样例的类别。

2. 代码实现

Node.java

package com.clxk1997;

/**
 * @Description 决策树节点
 * @Author Clxk
 * @Date 2019/4/22 14:16
 * @Version 1.0
 */
public class Node {

    /** Attribute (field) name stored in this node. */
    private String field;

    /** Attribute value stored in this node. */
    private String value;

    /** Creates a node whose field and value are both {@code null}. */
    public Node() {
        this(null, null);
    }

    /**
     * Creates a node holding the given field name and value.
     *
     * @param field attribute name to store
     * @param value attribute value to store
     */
    public Node(String field, String value) {
        this.field = field;
        this.value = value;
    }

    /** @return the stored attribute name (may be {@code null}) */
    public String getField() {
        return this.field;
    }

    /** @return the stored attribute value (may be {@code null}) */
    public String getValue() {
        return this.value;
    }

    /** @param field new attribute name */
    public void setField(String field) {
        this.field = field;
    }

    /** @param value new attribute value */
    public void setValue(String value) {
        this.value = value;
    }
}

DecisionTree.java

package com.clxk1997;

import java.util.ArrayList;

/**
 * @Description 决策树
 * @Author Clxk
 * @Date 2019/4/22 13:57
 * @Version 1.0
 */
public class DecisionTree {

    /** Field/value payload of this tree node. */
    private Node node;

    /** Sub-trees of this node; {@code null} until assigned. */
    private ArrayList<DecisionTree> childs;

    /**
     * Builds a fresh root node.
     *
     * @return a tree whose node is {@code ("root", "root")} and whose child list is empty
     */
    public static DecisionTree init() {
        DecisionTree root = new DecisionTree();
        root.node = new Node("root", "root");
        root.childs = new ArrayList<DecisionTree>();
        return root;
    }

    /** @return the payload of this node */
    public Node getNode() {
        return this.node;
    }

    /** @return the (mutable) list of sub-trees, or {@code null} if never assigned */
    public ArrayList<DecisionTree> getChilds() {
        return this.childs;
    }

    /** @param node new payload for this node */
    public void setNode(Node node) {
        this.node = node;
    }

    /** @param childs new list of sub-trees (stored by reference) */
    public void setChilds(ArrayList<DecisionTree> childs) {
        this.childs = childs;
    }
}

C45.java

package com.clxk1997;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.util.*;

/**
 * C4.5决策树算法
 * @author clxk
 *
 */
public class C45 {

    /** Default capacity of the static sample arrays and of the per-attribute result arrays. */
    private static final int MAXN = 0x3f;

    /** Marker stored as a node's field for leaf nodes; the leaf's value is the class label. */
    private static final String LEAF = "叶节点";

    /** Separator for sample lines: any run of whitespace. */
    private static final String WHITESPACE = "\\s+";

    /** Attribute names in input column order (left intact so test samples can be mapped). */
    private static List<String> fields = new ArrayList<>();

    /** Attribute names handed to {@link #train}; a separate list from {@link #fields}. */
    private static List<String> curfields = new ArrayList<>();

    /** Class-label names read from the input; a sample's label is its last column. */
    private static List<String> classfields = new ArrayList<>();

    /** Training samples, one token list per row; unused trailing slots are {@code null}. */
    @SuppressWarnings("unchecked")
    private static List<String>[] trains = new ArrayList[MAXN];

    /** Test-sample storage (test rows are currently classified as they are read). */
    @SuppressWarnings("unchecked")
    private static List<String>[] tests = new ArrayList[MAXN];

    /** Number of training samples read. */
    private static int count_train;

    /** Number of test samples read. */
    private static int count_test;

    /** Root of the decision tree. */
    private static DecisionTree tree = DecisionTree.init();

    /** Class label found by the most recent {@link #dfsTree} walk, or {@code null}. */
    static String ans = null;

    /**
     * Reads attribute names, class names and training samples from {@code lib/input_train.txt}.
     *
     * @throws Exception if the file is missing or malformed
     */
    public static void input() throws Exception {

        System.out.println("请输入属性集合的数量: ");
        try (Scanner scan = new Scanner(new FileInputStream(new File("lib/input_train.txt")))) {
            int cnt = scan.nextInt();
            System.out.println(cnt);
            System.out.println("请输入" + cnt + "个属性: ");
            String str;
            scan.nextLine();
            for (int i = 0; i < cnt; i++) {
                str = scan.nextLine();
                System.out.println(str);
                fields.add(str);
                curfields.add(str);
            }

            System.out.println("请输入类别集合的数量: ");
            cnt = scan.nextInt();
            scan.nextLine();
            System.out.println(cnt);
            System.out.println("请输入" + cnt + "个类别标签: ");
            for (int i = 0; i < cnt; i++) {
                str = scan.nextLine();
                System.out.println(str);
                classfields.add(str);
            }

            System.out.println("请输入训练集的数量: ");
            cnt = scan.nextInt();
            count_train = cnt;
            System.out.println(cnt);
            scan.nextLine();
            // Grow the sample array instead of overflowing the fixed default capacity.
            if (cnt > trains.length) {
                @SuppressWarnings("unchecked")
                List<String>[] grown = new ArrayList[cnt];
                trains = grown;
            }
            System.out.println("请输入训练集: ");
            for (int i = 0; i < cnt; i++) {
                str = scan.nextLine();
                System.out.println(str);
                String[] split = splitTokens(str);
                // Each row needs one value per attribute plus the class label as the last column.
                if (split.length <= fields.size()) {
                    throw new IllegalArgumentException("training row " + i + " has " + split.length
                            + " columns, expected " + (fields.size() + 1) + ": " + str);
                }
                trains[i] = new ArrayList<>(Arrays.asList(split));
            }
        }
    }

    /** Splits a sample line into tokens, ignoring leading/trailing and repeated whitespace. */
    private static String[] splitTokens(String line) {
        return line.trim().split(WHITESPACE);
    }

    /**
     * Computes the conditional entropy of every attribute, followed by the class entropy.
     *
     * @param trains      samples (rows ending with the class label; {@code null} ends the data)
     * @param fields      attribute names whose columns are at the same indices in each row
     * @param count_train number of non-null rows in {@code trains}
     * @return array where index {@code i < fields.size()} is attribute i's conditional entropy and
     *         index {@code fields.size() + j} is the class entropy (one slot per class name)
     */
    public static double[] getInfos(List<String>[] trains, List<String> fields, int count_train) {

        double[] infos = new double[Math.max(MAXN, fields.size() + classfields.size())];

        for (int i = 0; i < fields.size(); i++) {
            infos[i] = getInfo(i, false, trains, count_train);
        }

        for (int i = 0; i < classfields.size(); i++) {
            infos[i + fields.size()] = getInfo(i, true, trains, count_train);
        }

        return infos;
    }

    /**
     * Reads the training data, builds and prints the tree, then classifies the test data.
     *
     * @throws Exception if an input file is missing or malformed
     */
    public static void training() throws Exception {
        input();
        train(curfields, trains, count_train, tree);
        showTree(tree);
        inputTest();
    }

    /**
     * Reads samples from {@code lib/input_test.txt} and prints the predicted class for each.
     *
     * @throws FileNotFoundException if the test file is missing
     */
    public static void inputTest() throws FileNotFoundException {

        System.out.println("ce" + fields.size());
        System.out.println("请输入测试集的个数: ");
        try (Scanner scanner = new Scanner(new FileInputStream(new File("lib/input_test.txt")))) {
            count_test = scanner.nextInt();
            scanner.nextLine();
            Map<String, String> mp = new HashMap<>();
            for (int i = 0; i < count_test; i++) {
                mp.clear();
                System.out.println("请输入第" + i + "个测试样例: ");
                String line = scanner.nextLine();
                String[] split = splitTokens(line);
                if (split.length < fields.size()) {
                    throw new IllegalArgumentException("test row " + i + " has " + split.length
                            + " columns, expected " + fields.size() + ": " + line);
                }
                for (int j = 0; j < fields.size(); j++) {
                    mp.put(fields.get(j), split[j]);
                }

                String result = getAns(mp);
                System.out.println("在该测试样例情况下," + classfields.get(0) + " 应为 " + result);
            }
        }
    }

    /**
     * Classifies one sample.
     *
     * @param mp attribute name → attribute value for the sample
     * @return the predicted class label, or {@code null} if no path in the tree matches
     */
    public static String getAns(Map<String, String> mp) {
        // Reset so an unmatched sample does not silently reuse the previous sample's label.
        ans = null;
        dfsTree(tree, mp);
        return ans;
    }

    /**
     * Walks the tree along the branches matching {@code mp} and stores the reached leaf's label in
     * {@link #ans}.
     *
     * @param tree current subtree
     * @param mp   attribute name → attribute value for the sample
     */
    public static void dfsTree(DecisionTree tree, Map<String, String> mp) {
        if (LEAF.equals(tree.getNode().getField())) {
            ans = tree.getNode().getValue();
            return;
        }
        List<DecisionTree> children = tree.getChilds();
        if (children == null) {
            return;
        }
        if (children.size() == 1 && LEAF.equals(children.get(0).getNode().getField())) {
            ans = children.get(0).getNode().getValue();
            return;
        }
        for (DecisionTree child : children) {
            String field = child.getNode().getField();
            String value = child.getNode().getValue();
            if (mp.containsKey(field) && mp.get(field).equals(value)) {
                // Sibling branches carry distinct values of the same attribute: at most one matches.
                dfsTree(child, mp);
                return;
            }
        }
    }

    /**
     * Prints every node of the tree together with its direct children (pre-order).
     *
     * @param tree subtree to print
     */
    public static void showTree(DecisionTree tree) {

        System.out.print("当前节点属性为: " + tree.getNode().getField() + "   当前属性值为: " +
                tree.getNode().getValue());
        System.out.println();
        if (tree.getChilds() == null) return;
        for (DecisionTree child : tree.getChilds()) {
            System.out.println("子节点属性为: " + child.getNode().getField()
                    + "   子节点属性值为: " + child.getNode().getValue());
        }
        for (DecisionTree child : tree.getChilds()) {
            showTree(child);
        }
    }

    /**
     * Recursively grows the tree below {@code tree}.
     *
     * <p>Neither {@code fields} nor the rows in {@code trains} are modified; each branch receives
     * its own copies with the split column removed.
     *
     * @param fields      attribute names, aligned with the leading columns of each row
     * @param trains      samples of this subset (class label last; {@code null} ends the data)
     * @param count_train number of non-null rows in {@code trains}
     * @param tree        node to attach children to
     */
    public static void train(List<String> fields, List<String>[] trains, int count_train, DecisionTree tree) {

        if (tree.getChilds() == null) {
            tree.setChilds(new ArrayList<>());
        }
        // Nothing to learn from an empty subset.
        if (trains.length == 0 || trains[0] == null) {
            return;
        }

        // Purity is checked before field exhaustion so a pure subset always gets its leaf.
        if (isPure(trains)) {
            addLeaf(tree, trains[0].get(trains[0].size() - 1));
            return;
        }
        // Labels still disagree but no attribute is left to split on: use the majority label.
        if (fields.isEmpty()) {
            addLeaf(tree, majorityClass(trains));
            return;
        }

        // Entropies.
        System.out.println("计算得到所有属性的信息熵为: ");
        double[] infos = getInfos(trains, fields, count_train);
        for (int i = 0; i < fields.size(); i++) {
            System.out.format("%s : %.3f\n", fields.get(i), infos[i]);
        }
        System.out.println("计算得到类别D的信息熵为: ");
        for (int i = 0; i < classfields.size(); i++) {
            System.out.format("%s : %.3f\n", classfields.get(i), infos[i + fields.size()]);
        }

        // Information gain relative to the class entropy.
        System.out.println("计算得到所有属性的信息增益为: ");
        double classInfo = getInfo(0, true, trains, count_train);
        double[] gains = getGains(infos, classInfo, fields);
        for (int i = 0; i < fields.size(); i++) {
            System.out.format("%s : %.3f\n", fields.get(i), gains[i]);
        }

        // Split information.
        System.out.println("计算得到所有属性的分裂信息度量为: ");
        double[] h = new double[Math.max(MAXN, fields.size())];
        for (int i = 0; i < fields.size(); i++) {
            h[i] = getH(i, trains, count_train);
        }
        for (int i = 0; i < fields.size(); i++) {
            System.out.format("%s : %.3f\n", fields.get(i), h[i]);
        }

        // Gain ratio.
        System.out.println("计算得到所有属性的信息增益率为");
        double[] igr = getIGR(gains, h, fields);
        for (int i = 0; i < fields.size(); i++) {
            System.out.format("%s : %.3f\n", fields.get(i), igr[i]);
        }

        // Pick the attribute with the highest gain ratio (attribute 0 if none is positive).
        int index = 0;
        double maxd = 0.0;
        for (int i = 0; i < fields.size(); i++) {
            if (igr[i] > maxd) {
                maxd = igr[i];
                index = i;
            }
        }
        String field = fields.get(index);
        System.out.println("分裂属性为: " + field);

        // Private copy: the caller's list and sibling branches must keep their own attribute sets.
        List<String> remaining = new ArrayList<>(fields);
        remaining.remove(index);

        // Group rows by the split attribute's value, dropping that column from copied rows.
        Map<String, List<List<String>>> groups = new HashMap<>();
        for (List<String> row : trains) {
            if (row == null) break;
            List<String> reduced = new ArrayList<>(row);
            String value = reduced.remove(index);
            groups.computeIfAbsent(value, k -> new ArrayList<>()).add(reduced);
        }

        for (Map.Entry<String, List<List<String>>> entry : groups.entrySet()) {
            List<List<String>> rows = entry.getValue();
            @SuppressWarnings("unchecked")
            List<String>[] subset = rows.toArray(new List[rows.size()]);
            DecisionTree child = new DecisionTree();
            child.setNode(new Node(field, entry.getKey()));
            tree.getChilds().add(child);
            train(remaining, subset, rows.size(), child);
        }
    }

    /** Appends a leaf carrying {@code label} to {@code tree}'s children. */
    private static void addLeaf(DecisionTree tree, String label) {
        DecisionTree leaf = new DecisionTree();
        leaf.setNode(new Node(LEAF, label));
        tree.getChilds().add(leaf);
    }

    /** Returns the most frequent class label (ties go to the label that reached the count first). */
    private static String majorityClass(List<String>[] trains) {
        Map<String, Integer> counts = new HashMap<>();
        String best = null;
        int bestCount = 0;
        for (List<String> row : trains) {
            if (row == null) break;
            String label = row.get(row.size() - 1);
            int count = counts.merge(label, 1, Integer::sum);
            if (count > bestCount) {
                bestCount = count;
                best = label;
            }
        }
        return best;
    }

    /**
     * Tells whether all samples share the same class label.
     *
     * @param trains samples (class label last; {@code null} ends the data)
     * @return {@code true} if at most one distinct label occurs
     */
    public static boolean isPure(List<String>[] trains) {
        String label = null;
        for (List<String> row : trains) {
            if (row == null) break;
            String current = row.get(row.size() - 1);
            if (label == null) {
                label = current;
            } else if (!label.equals(current)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Computes an entropy (base 2).
     *
     * @param index       column of the attribute (ignored when {@code isClass} is true)
     * @param isClass     {@code true} for the class entropy, {@code false} for the conditional
     *                    entropy of the class given attribute {@code index}
     * @param trains      samples (class label last; {@code null} ends the data)
     * @param count_train number of non-null rows in {@code trains}
     * @return the requested entropy in bits
     */
    public static double getInfo(int index, boolean isClass, List<String>[] trains, int count_train) {

        if (isClass) {
            Map<String, Integer> labelCounts = new HashMap<>();
            for (List<String> row : trains) {
                if (row == null) break;
                labelCounts.merge(row.get(row.size() - 1), 1, Integer::sum);
            }
            return entropy(labelCounts.values(), count_train);
        }

        Map<String, Integer> valueCounts = new HashMap<>();                 // value -> rows
        Map<String, Map<String, Integer>> labelsByValue = new HashMap<>();  // value -> label -> rows
        for (List<String> row : trains) {
            if (row == null) break;
            String value = row.get(index);
            String label = row.get(row.size() - 1);
            valueCounts.merge(value, 1, Integer::sum);
            labelsByValue.computeIfAbsent(value, k -> new HashMap<>()).merge(label, 1, Integer::sum);
        }

        // Weighted sum of the class entropy within each attribute value.
        double result = 0;
        for (Map.Entry<String, Integer> entry : valueCounts.entrySet()) {
            int n = entry.getValue();
            result += (double) n / (double) count_train
                    * entropy(labelsByValue.get(entry.getKey()).values(), n);
        }
        return result;
    }

    /** Returns {@code -sum(p * log2 p)} with {@code p = count / total} over the given counts. */
    private static double entropy(Collection<Integer> counts, int total) {
        double result = 0;
        for (int count : counts) {
            double p = (double) count / (double) total;
            result -= p * (Math.log(p) / Math.log(2.0));
        }
        return result;
    }

    /**
     * Computes the information gain of every attribute.
     *
     * @param infos  conditional entropies, indexed like {@code fields}
     * @param d      class entropy
     * @param fields attribute names
     * @return {@code d - infos[i]} for each attribute index {@code i}
     */
    public static double[] getGains(double[] infos, double d, List<String> fields) {
        double[] gain = new double[Math.max(MAXN, fields.size())];
        for (int i = 0; i < fields.size(); i++) {
            gain[i] = d - infos[i];
        }
        return gain;
    }

    /**
     * Computes the split information of attribute {@code index}.
     *
     * @param index       column of the attribute
     * @param trains      samples ({@code null} ends the data)
     * @param count_train number of non-null rows in {@code trains}
     * @return entropy (bits) of the attribute's value distribution
     */
    public static double getH(int index, List<String>[] trains, int count_train) {
        Map<String, Integer> valueCounts = new HashMap<>();
        for (List<String> row : trains) {
            if (row == null) break;
            valueCounts.merge(row.get(index), 1, Integer::sum);
        }
        return entropy(valueCounts.values(), count_train);
    }

    /**
     * Computes the information-gain ratio of every attribute.
     *
     * @param gains  information gains, indexed like {@code fields}
     * @param h      split information, indexed like {@code fields}
     * @param fields attribute names
     * @return {@code gains[i] / h[i]}, or 0 where {@code h[i]} is 0 (single-valued attribute)
     */
    public static double[] getIGR(double[] gains, double[] h, List<String> fields) {
        double[] ratios = new double[Math.max(MAXN, fields.size())];
        for (int i = 0; i < fields.size(); i++) {
            // A single-valued attribute has h == 0 (and gain 0); avoid the 0/0 NaN.
            ratios[i] = h[i] == 0 ? 0 : gains[i] / h[i];
        }
        return ratios;
    }


    public static void main(String[] args) {
        try {
            training();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

}

input_train.txt

4
天气
温度
湿度
是否有风
1
是否适合打网球
10
晴 热 高 否 否
晴 热 高 是 否
阴 热 高 否 是
雨 温 高 否 是
雨 凉爽 中 否 是
雨 凉爽 中 是 否
阴 凉爽 中 是 是
晴 温 高 否 否
晴 凉爽 中 否 是
雨 温 中 否 是

input_test.txt

4
晴 温 中 是
阴 温 高 是
阴 热 中 否
雨 温 高 是

猜你喜欢

转载自blog.csdn.net/l1832876815/article/details/89469015