决策树分类算法:C4.5算法

决策树分类算法:C4.5算法

【每次以信息增益率最大的特征项Ai为节点建立决策树】
【决策树算法思路参考】

决策树分类算法公共基类

```java
package base;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

/**
 * Common base for decision-tree learners (ID3 / C4.5 style) over categorical data.
 * Data sets are {@code Matrix_2D<String>} whose rows hold the feature values followed by
 * the class label in the last column. Subclasses only decide which feature to split on.
 */
public abstract class GeneralDecTreeHandler {

	/** Result returned by {@link #classification} when a sample cannot be routed to a leaf. */
	private static final String NO_CLASSIFICATION = "no such classification";

	/**
	 * Chooses the feature column on which to split {@code dataSet}.
	 *
	 * @param dataSet current (sub)training set with the class label in the last column; never empty
	 * @return the column index of the chosen feature, or a negative value when no feature is
	 *         worth splitting on (the caller then creates a majority-class leaf)
	 */
	protected abstract int chooseBestFeatureToSplit(Matrix_2D<String> dataSet);

	/**
	 * Partitions the data set: keeps the rows whose column {@code index} equals {@code val}
	 * and removes that column from every kept row.
	 *
	 * @param dataSet source rows; not modified
	 * @param index   column to test and drop
	 * @param val     value the column must equal
	 * @return a new matrix containing the reduced matching rows
	 */
	protected Matrix_2D<String> splitDataSet(Matrix_2D<String> dataSet, int index, String val) {
		Matrix_2D<String> retDataSet = new Matrix_2D<String>();
		final int row = dataSet.getRowDimension();
		for (int i = 0; i < row; ++i) {
			ArrayList<String> line = dataSet.get(i);
			if (line.get(index).equals(val)) {
				ArrayList<String> reduced = new ArrayList<String>(line);
				reduced.remove(index); // int overload: removes by position, not by value
				retDataSet.putLine(reduced);
			}
		}
		return retDataSet;
	}

	/**
	 * Returns the most frequent label. Ties are broken by HashMap iteration order;
	 * an empty array yields {@code ""}.
	 */
	private String majorityClassificationCount(String[] labels) {
		Map<String, Integer> labelCount = new HashMap<String, Integer>();
		for (String s : labels) {
			Integer c = labelCount.get(s);
			labelCount.put(s, c == null ? 1 : c + 1);
		}
		int count = -1;
		String t = "";
		for (Map.Entry<String, Integer> e : labelCount.entrySet()) {
			if (e.getValue() > count) {
				count = e.getValue();
				t = e.getKey();
			}
		}
		return t;
	}

	/**
	 * Recursively builds a decision tree from {@code dataSet}.
	 *
	 * @param dataSet  training rows: feature values followed by the class label in the last column
	 * @param features names of the feature columns in column order (length = columns - 1)
	 * @return the root node; internal nodes hold a feature name and map each observed value of
	 *         that feature to a subtree, leaves hold a class label and have {@code null} children
	 * @throws IllegalArgumentException if {@code dataSet} has no rows
	 */
	public TreeNode creaDecTree(Matrix_2D<String> dataSet, String[] features) {
		final int row = dataSet.getRowDimension();
		if (row == 0) {
			throw new IllegalArgumentException("cannot build a decision tree from an empty data set");
		}
		final int col = dataSet.getColDimension();
		String[] labelsList = new String[row];
		for (int i = 0; i < row; ++i) {
			labelsList[i] = dataSet.get(i, col - 1);
		}
		// Every row has the same class: pure leaf.
		if (allSame(labelsList)) return new TreeNode(labelsList[0], null);
		// Only the label column is left: majority-vote leaf.
		if (col == 1) return new TreeNode(majorityClassificationCount(labelsList), null);
		int bestFeature = chooseBestFeatureToSplit(dataSet);
		// No informative feature (e.g. no positive gain ratio): majority-vote leaf instead of
		// indexing features[-1].
		if (bestFeature < 0) return new TreeNode(majorityClassificationCount(labelsList), null);
		String bestFeatureLabel = features[bestFeature];
		// Remove by position rather than by name so duplicate feature names cannot corrupt the array.
		String[] subFeatures = removeAt(features, bestFeature);
		Set<String> uniqFeatureVals = new HashSet<String>(); // distinct values, unordered
		for (int i = 0; i < row; ++i) uniqFeatureVals.add(dataSet.get(i, bestFeature));
		Map<String, TreeNode> child = new HashMap<String, TreeNode>();
		for (String s : uniqFeatureVals) {
			child.put(s, creaDecTree(splitDataSet(dataSet, bestFeature, s), subFeatures));
		}
		return new TreeNode(bestFeatureLabel, child);
	}

	/** Returns true if every element equals the first one (vacuously true for an empty array). */
	private static boolean allSame(String[] values) {
		for (String s : values) {
			if (!s.equals(values[0])) return false;
		}
		return true;
	}

	/** Returns a copy of {@code original} without the element at position {@code index}. */
	private static String[] removeAt(String[] original, int index) {
		String[] result = new String[original.length - 1];
		System.arraycopy(original, 0, result, 0, index);
		System.arraycopy(original, index + 1, result, index, original.length - index - 1);
		return result;
	}

	/**
	 * Classifies one sample by walking the tree from {@code tree} down to a leaf.
	 *
	 * @param tree     root of a tree built by {@link #creaDecTree}
	 * @param features the full feature-name array the tree was built with; maps a node's
	 *                 feature name back to its column in {@code sample}
	 * @param sample   one row of feature values (a trailing label column is ignored)
	 * @return the leaf's class label, or {@code "no such classification"} when the tree is
	 *         null, a node's feature name is not in {@code features}, or the sample holds a
	 *         value that was never seen for that feature during training
	 */
	public String classification(TreeNode tree, String[] features, ArrayList<String> sample) {
		TreeNode node = tree;
		while (node != null && node.getChildren() != null) {
			int column = getIndex(features, node.getElement());
			if (column < 0 || column >= sample.size()) return NO_CLASSIFICATION;
			node = node.getChildren().get(sample.get(column));
		}
		return node == null ? NO_CLASSIFICATION : node.getElement();
	}

	/** Returns the position of {@code s} in {@code arr}, or -1 if absent or {@code s} is null. */
	private int getIndex(String[] arr, String s) {
		if (s == null) return -1;
		for (int i = 0; i < arr.length; ++i) {
			if (s.equals(arr[i])) return i;
		}
		return -1;
	}

	/**
	 * Reads a comma-separated data file into a matrix, one row per non-blank line, and
	 * shuffles the row order uniformly at random.
	 *
	 * @param path file to read (decoded with the platform default charset)
	 * @return the shuffled rows, or {@code null} if {@code path} is not an existing regular file
	 *         (its absolute path is printed in that case)
	 * @throws IOException if reading fails
	 */
	public static Matrix_2D<String> readDataFile(String path) throws IOException {
		File file = new File(path);
		if (!file.exists() || !file.isFile()) {
			System.out.println(file.getAbsolutePath());
			return null;
		}
		ArrayList<ArrayList<String>> trainingSet = new ArrayList<ArrayList<String>>();
		// try-with-resources closes the reader even when readLine throws.
		try (BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(file)))) {
			String str;
			while ((str = reader.readLine()) != null) {
				// A blank line (e.g. a trailing newline) would become a 1-column row and break
				// the column count, which is taken from the first row after shuffling.
				if (str.trim().isEmpty()) continue;
				trainingSet.add(new ArrayList<String>(Arrays.asList(str.split(","))));
			}
		}
		Collections.shuffle(trainingSet);
		return new Matrix_2D<String>(trainingSet);
	}

	/**
	 * Fraction of {@code samples} whose predicted label equals their last-column label.
	 * Returns NaN when {@code samples} is empty.
	 */
	private double report(TreeNode tree, String[] features, Matrix_2D<String> samples) {
		final int rows = samples.getRowDimension();
		int num = 0;
		for (int i = 0; i < rows; ++i) {
			String expected = samples.get(i, samples.getColDimension() - 1);
			if (classification(tree, features, samples.get(i)).equals(expected)) ++num;
		}
		return num / (double) rows;
	}

	/**
	 * Estimates the accuracy of a tree built by this handler.
	 *
	 * @param features feature names for the columns of {@code dataSet}
	 * @param dataSet  labelled rows
	 * @param self     {@code true}: train and test on the whole data set (resubstitution accuracy);
	 *                 {@code false}: assign each row to training or test with probability 1/2
	 * @return fraction of correctly classified test rows
	 */
	public double reportModel(String[] features, Matrix_2D<String> dataSet, boolean self) {
		if (self) return report(creaDecTree(dataSet, features), features, dataSet);
		Matrix_2D<String> training = new Matrix_2D<String>();
		Matrix_2D<String> test = new Matrix_2D<String>();
		for (int i = 0; i < dataSet.getRowDimension(); ++i) {
			if (Math.random() >= 0.50) training.putLine(dataSet.get(i));
			else test.putLine(dataSet.get(i));
		}
		return report(creaDecTree(training, features), features, test);
	}
}

```

C4.5算法选择特征项及测试

```java
package c4_5;

import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Set;

import base.GeneralDecTreeHandler;
import base.Matrix_2D;
import base.TreeNode;

/**
 * C4.5 decision-tree learner: at each node, splits on the feature with the largest
 * information gain ratio.
 */
public class C4_5 extends GeneralDecTreeHandler {

	/** Natural log of 2, used to turn natural logarithms into base-2 logarithms. */
	private static final double LN2 = Math.log(2.0);

	/**
	 * Selects the feature column with the largest information gain ratio.
	 *
	 * @param dataSet rows of feature values with the class label in the last column
	 * @return index of the best feature, or -1 if no feature has a strictly positive gain ratio
	 */
	@Override
	protected int chooseBestFeatureToSplit(Matrix_2D<String> dataSet) {
		final int col = dataSet.getColDimension();
		double bestIGR = 0.0;
		int bestFeature = -1;
		for (int i = 0; i < col - 1; ++i) {
			double infoGainRatio = calInfoGainRation(dataSet, i);
			if (infoGainRatio > bestIGR) {
				bestIGR = infoGainRatio; // keep the running maximum
				bestFeature = i;
			}
		}
		return bestFeature;
	}

	/**
	 * Information gain ratio of splitting {@code ds} on column {@code index}:
	 * (H(D) - H(D | A)) / SplitInfo(A).
	 *
	 * @return the gain ratio, or 0 when the column has a single distinct value (SplitInfo = 0)
	 */
	private double calInfoGainRation(Matrix_2D<String> ds, int index) {
		final double splitInfo = calSplitInformation(ds, index);
		// One distinct value: the split is useless and the ratio would be 0/0 (NaN).
		if (splitInfo <= 0.0) return 0.0;
		final int row = ds.getRowDimension();
		final double baseEntropy = calShannonEntropy(ds);
		Set<String> featureSet = new HashSet<String>();
		for (int i = 0; i < row; ++i) featureSet.add(ds.get(i, index));
		double newEntropy = 0.0; // conditional entropy H(D | A)
		for (String s : featureSet) {
			Matrix_2D<String> retData = splitDataSet(ds, index, s);
			double pro = retData.getRowDimension() / (double) row;
			newEntropy += pro * calShannonEntropy(retData);
		}
		return (baseEntropy - newEntropy) / splitInfo;
	}

	/**
	 * Split information of column {@code index}: -sum_v p_v * log2(p_v), where p_v is the
	 * fraction of rows whose value in that column is v.
	 */
	private double calSplitInformation(Matrix_2D<String> ds, int index) {
		return columnEntropy(ds, index);
	}

	/**
	 * Shannon entropy (base 2) of the class labels, i.e. of the last column of {@code ds}.
	 *
	 * @param ds non-empty data set
	 */
	public static double calShannonEntropy(Matrix_2D<String> ds) {
		return columnEntropy(ds, ds.getColDimension() - 1);
	}

	/** Base-2 entropy of the empirical value distribution in column {@code column} of {@code ds}. */
	private static double columnEntropy(Matrix_2D<String> ds, int column) {
		final int m = ds.getRowDimension();
		HashMap<String, Integer> counts = new HashMap<String, Integer>();
		for (int i = 0; i < m; i++) {
			String value = ds.get(i, column);
			Integer c = counts.get(value);
			counts.put(value, c == null ? 1 : c + 1);
		}
		double entropy = 0.0;
		for (int c : counts.values()) {
			double rate = c / (double) m; // double, not float, to avoid precision loss
			entropy -= rate * Math.log(rate) / LN2;
		}
		return entropy;
	}

	/**
	 * Demo: builds a C4.5 tree on a comma-separated data file (label in the last column),
	 * prints its accuracy on the training data, then the accuracy of 30 random 1:1
	 * train/test splits and their mean.
	 *
	 * @param args optional; {@code args[0]} overrides the default data file
	 */
	public static void main(String[] args) throws IOException {
		// Other data sets used with this demo: divorce.txt, StudentAcademicsPerformance.txt
		final String dataFile = args.length > 0 ? args[0] : "AutismAdultDataPlus.txt";
		C4_5 tool = new C4_5();
		Matrix_2D<String> trainingSet = C4_5.readDataFile(dataFile);
		if (trainingSet == null || trainingSet.getRowDimension() == 0) {
			System.err.println("cannot read any data from: " + dataFile);
			return;
		}
		String[] features = new String[trainingSet.getColDimension() - 1];
		for (int i = 0; i < features.length; ++i)
			features[i] = "特征" + String.valueOf(i);
		TreeNode tree = tool.creaDecTree(trainingSet, features);
		int num = 0;
		final int row = trainingSet.getRowDimension(), col = trainingSet.getColDimension();
		for (int i = 0; i < row; ++i) {
			String tmp = tool.classification(tree, features, trainingSet.get(i));
			if (tmp.equals(trainingSet.get(i).get(col - 1))) {
				++num;
			}
		}
		System.out.println("测试实例数:"+row+",分类正确数:"+num+",分类精度:"+(num/(double)row));
		System.out.println("*30次1:1模型测试:");
		double pp = 0.0;
		for (int i = 1; i <= 30; ++i) {
			System.out.print("第"+i+"次:");
			double p0 = tool.reportModel(features, trainingSet, false);
			pp += p0;
			System.out.println(p0);
		}
		System.out.println("30次均值:"+(pp/30));
	}
}

```

节点类:TreeNode.java

```java
package base;

import java.util.Map;

/**
 * Immutable decision-tree node. In GeneralDecTreeHandler an internal node stores a feature
 * name and maps feature values to subtrees; a leaf stores a class label and has
 * {@code null} children.
 */
public class TreeNode {
	/** Feature name (internal node) or class label (leaf). */
	private final String element;
	/** Subtrees keyed by feature value; {@code null} for a leaf. */
	private final Map<String, TreeNode> children;

	/** Creates a node whose element and children are both {@code null}. */
	public TreeNode() {
		this(null, null);
	}

	/**
	 * @param e the node's element
	 * @param c the child map (stored as-is, not copied), or {@code null} for a leaf
	 */
	public TreeNode(String e, Map<String, TreeNode> c) {
		this.element = e;
		this.children = c;
	}

	/** @return the child map passed at construction (may be {@code null}) */
	public Map<String, TreeNode> getChildren() {
		return this.children;
	}

	/** @return the element passed at construction (may be {@code null}) */
	public String getElement() {
		return this.element;
	}
}

```

工具类:Matrix_2D.java

```java
package base;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

/**
 * Simple row-major 2-D table backed by nested {@link ArrayList}s.
 * Rows are copied on insertion, so later changes to a caller's list do not affect the
 * matrix. Rows are not forced to have equal length; {@link #getColDimension()} reports the
 * length of the first row.
 */
public class Matrix_2D<T> {
	/** Row storage (package-private); every row is a copy owned by this matrix. */
	ArrayList<ArrayList<T>> data;

	/** Creates an empty matrix. */
	public Matrix_2D() {
		data = new ArrayList<ArrayList<T>>();
	}

	/** Creates a matrix holding a copy of every row of {@code d}. */
	public Matrix_2D(ArrayList<ArrayList<T>> d) {
		data = new ArrayList<ArrayList<T>>(d.size());
		for (ArrayList<T> val : d) {
			// Copy directly instead of calling the overridable putLine from a constructor.
			data.add(new ArrayList<T>(val));
		}
	}

	/** Appends a copy of {@code line} as the new last row. */
	public void putLine(ArrayList<T> line) {
		data.add(new ArrayList<T>(line));
	}

	/** @return number of rows */
	public int getRowDimension() {
		return data.size();
	}

	/** @return number of columns, taken from the first row; 0 for an empty matrix */
	public int getColDimension() {
		return data.isEmpty() ? 0 : data.get(0).size();
	}

	/** Returns row {@code i} itself (not a copy); mutating it mutates the matrix. */
	public ArrayList<T> get(int i) {
		return data.get(i);
	}

	/** @return the element at row {@code i}, column {@code j} */
	public T get(int i, int j) {
		return data.get(i).get(j);
	}

	/** Removes and returns the element at row {@code i}, column {@code j} (only that row shrinks). */
	public T remove(int i, int j) {
		return data.get(i).remove(j);
	}

	/** Removes and returns row {@code index}. */
	public ArrayList<T> remove(int index) {
		return data.remove(index);
	}

	/**
	 * Returns a copy of {@code original} without {@code str}. Assumes {@code str} occurs exactly
	 * once: the result always has length {@code original.length - 1}. If {@code str} is absent,
	 * an {@link ArrayIndexOutOfBoundsException} is thrown. If it occurs more than once, the
	 * trailing slots stay {@code null}.
	 */
	public static String[] subArray(String[] original, String str) {
		String[] subArray = new String[original.length - 1];
		int k = 0;
		for (String s : original) {
			if (!s.equals(str)) subArray[k++] = s;
		}
		return subArray;
	}

	/** @return a new mutable list with the same elements as {@code data} */
	public static ArrayList<String> copyArrayList(ArrayList<String> data) {
		return new ArrayList<String>(data);
	}

	/**
	 * Returns the most frequent label. Ties are broken by HashMap iteration order;
	 * an empty list yields {@code ""}.
	 */
	public static String majority(ArrayList<String> labels) {
		Map<String, Integer> labelCount = new HashMap<String, Integer>();
		for (String s : labels) {
			Integer c = labelCount.get(s);
			labelCount.put(s, c == null ? 1 : c + 1);
		}
		int count = -1;
		String t = "";
		for (Map.Entry<String, Integer> e : labelCount.entrySet()) {
			if (e.getValue() > count) {
				count = e.getValue();
				t = e.getKey();
			}
		}
		return t;
	}
}

```

参考文章:
机器学习算法:18大数据挖掘的经典算法以及代码Java实现
ID3、C4.5算法介绍以及java代码实现
归纳决策树ID3的实现

发布了7 篇原创文章 · 获赞 0 · 访问量 594

猜你喜欢

转载自blog.csdn.net/qq_34262612/article/details/104104532