// Daily exercise, for work: 2020-05-04, problem 62

package data.bjsj.fjjb;


import org.apache.spark.Accumulator;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;

import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.SparkSession;
import scala.Tuple2;


/**
 * Trains a binary logistic-regression classifier with Spark MLlib (L-BFGS),
 * evaluates it on a held-out split, prints the accuracy to {@code System.err}
 * and persists the model when the accuracy is good enough.
 *
 * <p>Slogan: even the clock keeps moving forward; how could a person stop here!
 *
 * @author 雪瞳
 */
public class LogisticModel {

    /** Tab-separated input records, one sample per line. */
    private static final String INPUT_PATH = "./save/rootData";

    /** Directory the trained model is written to. */
    private static final String MODEL_PATH = "./save/model";

    /** Train/test split weights passed to {@code randomSplit}. */
    private static final double[] SPLIT_WEIGHTS = {0.7, 0.3};

    /** Fixed seed so the train/test split is reproducible between runs. */
    private static final long SPLIT_SEED = 100L;

    /** Minimum accuracy, in percent, required before the model is saved. */
    private static final double MIN_ACCURACY_PERCENT = 80.00;

    /**
     * Runs the whole train / evaluate / save pipeline on a local Spark master.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        SparkSession session = SparkSession.builder().appName("logistic").master("local").getOrCreate();
        try {
            SparkContext sc = session.sparkContext();
            JavaSparkContext jsc = JavaSparkContext.fromSparkContext(sc);
            // Use the canonical upper-case level name so this does not rely on
            // Spark normalising the case of the argument.
            jsc.setLogLevel("ERROR");

            JavaRDD<LabeledPoint> samples = jsc.textFile(INPUT_PATH).map(new Function<String, LabeledPoint>() {
                @Override
                public LabeledPoint call(String line) throws Exception {
                    return parseLine(line);
                }
            });

            RDD<LabeledPoint>[] splits = samples.rdd().randomSplit(SPLIT_WEIGHTS, SPLIT_SEED);
            RDD<LabeledPoint> trainingData = splits[0];
            // The test split is scanned twice (total count and correct count),
            // so cache it instead of re-reading and re-parsing the input.
            JavaRDD<LabeledPoint> testData = splits[1].toJavaRDD().cache();

            LogisticRegressionWithLBFGS lr = new LogisticRegressionWithLBFGS();
            lr.setNumClasses(2);
            lr.setIntercept(true);
            final LogisticRegressionModel model = lr.run(trainingData);

            long total = testData.count();
            // Count the correct predictions with a plain filter + count; this
            // avoids the deprecated Accumulator API and the zip of two
            // separately computed RDDs.
            long correct = testData.filter(new Function<LabeledPoint, Boolean>() {
                @Override
                public Boolean call(LabeledPoint point) throws Exception {
                    return Double.compare(point.label(), model.predict(point.features())) == 0;
                }
            }).count();

            // Guard against an empty test split, which would otherwise give NaN.
            double rate = total == 0 ? 0.0 : correct / (double) total;
            double accuracyPercent = rate * 100;

            System.err.println("总数目是:" + total);
            System.err.println("正确数目是:" + correct);
            System.err.println("正确率是:" + accuracyPercent + "%");

            // Compare percent with percent. The previous code compared the
            // 0..1 fraction against 80.0, which was always "below", so the
            // model was saved unconditionally. Only keep models that reach
            // the required accuracy.
            if (Double.compare(accuracyPercent, MIN_ACCURACY_PERCENT) >= 0) {
                // NOTE(review): saving presumably fails if MODEL_PATH already
                // exists; confirm and clear the directory between runs if so.
                model.save(sc, MODEL_PATH);
            }
        } finally {
            // Always release the local Spark resources, even on failure.
            session.stop();
        }
    }

    /**
     * Converts one tab-separated record into a {@link LabeledPoint}.
     *
     * <p>Sample record (fields separated by tabs, with an empty field before the label):
     * {@code "2015-11-01 20:20:16" 1.85999330468501 1.22359452534749 2.51578969727773
     * -0.403918740333512 0.0149184125297424 <empty> 0}
     *
     * <p>The first field is skipped, the last field is the label, and the field
     * just before the label is skipped as well, so the features are fields
     * {@code 1 .. length - 3}.
     *
     * @param line one raw input line
     * @return the label and dense feature vector parsed from {@code line}
     * @throws IllegalArgumentException if the line has too few fields to hold
     *         at least one feature
     * @throws NumberFormatException if a feature or the label is not a number
     */
    private static LabeledPoint parseLine(String line) {
        String[] fields = line.split("\t");
        if (fields.length < 4) {
            throw new IllegalArgumentException(
                    "Expected at least 4 tab-separated fields but got " + fields.length + ": " + line);
        }
        double label = Double.parseDouble(fields[fields.length - 1]);
        double[] features = new double[fields.length - 3];
        for (int i = 0; i < features.length; i++) {
            // +1 skips the leading (timestamp) field.
            features[i] = Double.parseDouble(fields[i + 1]);
        }
        return new LabeledPoint(label, Vectors.dense(features));
    }
}

  

// You may also like
//
// Reposted from www.cnblogs.com/walxt/p/12825682.html
// Today's recommendations