Spark MLlib 垃圾邮件分类

pom依赖

<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-core_2.12</artifactId>
    <version>2.4.0</version>
</dependency>

<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-streaming_2.12</artifactId>
    <version>2.4.0</version>
</dependency>
<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-mllib_2.12</artifactId>
    <version>2.4.0</version>
</dependency>
<dependency>
    <groupId>com.thoughtworks.paranamer</groupId>
    <artifactId>paranamer</artifactId>
    <version>2.8</version>
</dependency>

代码

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.classification.LogisticRegressionWithSGD;
import org.apache.spark.mllib.feature.HashingTF;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;

import java.util.Arrays;
import java.util.Iterator;

/**
 * Spam-classification demo built on Spark MLlib logistic regression (SGD).
 *
 * <p>Two plain-text files are read; every non-blank line of the first file is one
 * training document labelled {@code 0}, and every non-blank line of the second file is
 * one training document labelled {@code 1}. Each document is split on single spaces and
 * hashed into a fixed-size term-frequency vector. The trained model is then asked to
 * classify two hard-coded sample sentences and the predictions are printed to stdout.
 */
public class Regression {

    /** Default path of the label-0 training file, used when no first argument is given. */
    private static final String DEFAULT_CLASS0_PATH =
            "C:\\workspace\\sparkTest\\src\\test\\java\\data\\1.txt";

    /** Default path of the label-1 training file, used when no second argument is given. */
    private static final String DEFAULT_CLASS1_PATH =
            "C:\\workspace\\sparkTest\\src\\test\\java\\data\\2.txt";

    /** Number of hash buckets (vector dimensions) used by {@link HashingTF}. */
    private static final int NUM_FEATURES = 10;

    /**
     * Trains the model and prints predictions for two sample sentences.
     *
     * @param args optional; {@code args[0]} overrides the label-0 training file path and
     *             {@code args[1]} overrides the label-1 training file path. With no
     *             arguments the built-in default paths are used.
     */
    public static void main(String[] args) {
        String class0Path = args.length > 0 ? args[0] : DEFAULT_CLASS0_PATH;
        String class1Path = args.length > 1 ? args[1] : DEFAULT_CLASS1_PATH;

        SparkConf conf = new SparkConf().setAppName("asd").setMaster("local[2]");

        // try-with-resources: the context is stopped even when loading or training fails.
        try (JavaSparkContext sc = new JavaSparkContext(conf)) {
            sc.setLogLevel("ERROR");

            // One shared hasher so training and prediction vectors live in the same space.
            final HashingTF tf = new HashingTF(NUM_FEATURES);

            // Merge both labelled sets into a single training set.
            JavaRDD<LabeledPoint> trainData =
                    labeledDocuments(sc.textFile(class0Path), 0.0, tf)
                            .union(labeledDocuments(sc.textFile(class1Path), 1.0, tf));
            trainData.cache(); // SGD iterates over the data repeatedly, so keep it cached.

            try {
                // Fit a logistic-regression model with stochastic gradient descent.
                LogisticRegressionModel model =
                        new LogisticRegressionWithSGD().run(trainData.rdd());

                // Sample sentences to classify.
                String t = "taltic ddd good gdds democracy fss care unassuming asdas mentality asdas";
                String f = "good good asda sdasdas sada asd sad asd a";

                // Featurize the samples exactly like the training documents.
                Vector transformT = tf.transform(Arrays.asList(t.split(" ")));
                Vector transformF = tf.transform(Arrays.asList(f.split(" ")));

                // Print the predicted label for each sample.
                System.out.println(t + " 结果为: " + model.predict(transformT));
                System.out.println(f + " 结果为: " + model.predict(transformF));
            } finally {
                // Release the cached training data even if training or prediction throws.
                trainData.unpersist();
            }
        }
    }

    /**
     * Turns each non-blank line into one labelled term-frequency vector.
     *
     * <p>The whole line is treated as a single document. Splitting lines into individual
     * words before featurizing would make every training example a one-word document,
     * which does not match how whole sentences are featurized at prediction time.
     *
     * @param lines raw text, one document per line
     * @param label class label assigned to every document in {@code lines}
     * @param tf    hasher used to build the feature vectors
     * @return one {@link LabeledPoint} per non-blank input line
     */
    private static JavaRDD<LabeledPoint> labeledDocuments(
            JavaRDD<String> lines, final double label, final HashingTF tf) {
        return lines
                // Blank lines carry no terms and would only add noise to the training set.
                .filter(line -> !line.trim().isEmpty())
                .map(line -> new LabeledPoint(label, tf.transform(Arrays.asList(line.split(" ")))));
    }
}

猜你喜欢

转载自blog.csdn.net/weixin_42660202/article/details/88424812