随机森林 RandomForest java

        随机森林是由多棵树组成的分类或回归方法。主要思想来源于Bagging算法,Bagging技术思想主要是给定一弱分类器及训练集,让该学习算法训练多轮,每轮的训练集由原始训练集中有放回的随机抽取,大小一般跟原始训练集相当,这样依次训练多个弱分类器,最终的分类由这些弱分类器组合,对于分类问题一般采用多数投票法,对于回归问题一般采用简单平均法。随机森林在bagging的基础上,每个弱分类器都是决策树,决策树的生成过程中中,在属性的选择上增加了依一定概率选择属性,在这些属性中选择最佳属性及分割点,传统做法一般是全部属性中去选择最佳属性,这样随机森林有了样本选择的随机性,属性选择的随机性,这样一来增加了每个分类器的差异性、不稳定性及一定程度上避免每个分类器的过拟合(一般决策树有过拟合现象),由此组合分类器增加了最终的泛化能力。下面是代码的简单实现

/**
 * 随机森林  回归问题
 * @author ysh   1208706282
 *
 */
public class RandomForest {
    List<Sample> mSamples;
    List<Cart> mCarts;
    double mFeatureRate;
    int mMaxDepth;
    int mMinLeaf;
    Random mRandom;
    /**
     * 加载数据   回归树
     * @param path
     * @param regex
     * @throws Exception
     */
    public  void loadData(String path,String regex) throws Exception{
        mSamples = new ArrayList<Sample>();
        BufferedReader reader = new BufferedReader(new FileReader(path));
        String line = null;
        String splits[] = null;
        Sample sample = null;
        while(null != (line=reader.readLine())){
            splits = line.split(regex);
            sample = new Sample();
            sample.label = Double.valueOf(splits[0]);
            sample.feature = new ArrayList<Double>(splits.length-1);
            for(int i=0;i<splits.length-1;i++){
                sample.feature.add(new Double(splits[i+1]));
            }
            mSamples.add(sample);
        }
        reader.close();
    }
    public void train(int iters){
        mCarts = new ArrayList<Cart>(iters);
        Cart cart = null;
        for(int iter=0;iter<iters;iter++){
            cart = new Cart();
            cart.mFeatureRate = mFeatureRate;
            cart.mMaxDepth = mMaxDepth;
            cart.mMinLeaf = mMinLeaf;
            cart.mRandom = mRandom;
            List<Sample> s = new ArrayList<Sample>(mSamples.size());
            for(int i=0;i<mSamples.size();i++){
                s.add(mSamples.get(cart.mRandom.nextInt(mSamples.size())));
            }
            cart.setData(s);
            cart.train();
            mCarts.add(cart);
            System.out.println("iter: "+iter);
            s = null;
        }
    }
    /**
     * 回归问题简单平均法  分类问题多数投票法
     * @param sample
     * @return
     */
    public double classify(Sample sample){
        double val = 0;
        for(Cart cart:mCarts){
            val += cart.classify(sample);
        }
        return val/mCarts.size();
    }
    /**
     * @param args
     * @throws Exception
     */
    public static void main(String[] args) throws Exception {
        // TODO Auto-generated method stub
        RandomForest forest = new RandomForest();
        forest.loadData("F:/2016-contest/20161001/train_data_1.csv", ",");
        forest.mFeatureRate = 0.8;
        forest.mMaxDepth = 3;
        forest.mMinLeaf = 1;
        forest.mRandom = new Random();
        forest.mRandom.setSeed(100);
        forest.train(100);
        
        List<Sample> samples = Cart.loadTestData("F:/2016-contest/20161001/valid_data_1.csv", true, ",");
        double sum = 0;
        for(Sample s:samples){
            double val = forest.classify(s);
            sum += (val-s.label)*(val-s.label);
            System.out.println(val+"  "+s.label);
        }
        System.out.println(sum/samples.size()+"  "+sum);
        System.out.println(System.currentTimeMillis());
    }

}


猜你喜欢

转载自blog.csdn.net/ysh126/article/details/53125858