SpringBoot 仿抖音短小程序开发 全栈式实战项目2018

Spark基于NaiveBayes朴素贝叶斯算法实现中文垃圾邮件分类实战(Java版 / Scala版)
网上很少能找到Spark millib系列算法对纯中文垃圾邮件分类的demo,此Demo做了Java + Scala的混合调用,训练数据做了Java/Scala两个版本的。
如有问题请私信交流
主要分为以下几个过程:
一、数据集下载
数据集来源于网络,具体地址我忘记了,所以分享在网盘,自行下载
链接:https://pan.baidu.com/s/1n0Xp0MIcL7C8SPgZFmdKLw
提取码:327b
二、特征抽取与建模训练模型
这里应该是两个步骤,为了方便程序的方便调用我就写一个demo里面了
特征抽取主要使用了庖丁分词,然后自己写了一个映射去生成特征值列表
如果对庖丁分词不了解的请自行了解一下,我这里是直接封装调用了

训练数据为Java/Scala两个版本,二选一即可,调用类调用的是Scala版本,使用Java版略作修改即可

Scala版

SpamTrainScala.scala
里面会用到我封装的部分util工具,我会在文章底部贴出代码
以下为代码部分:
package top.it1002.spark.ml


import java.io.{File, PrintWriter}
import java.util

import org.apache.spark.mllib.classification.NaiveBayes
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.{SparkConf, SparkContext}
import top.it1002.util.{DoFile, MessageLog, PaoDingCut}

/**
  * @Author       王磊
  * @Date         2018/12/14
  * @Description  Scala版数据特征抽取与模型训练
  */

object SpamTrainScala {
  def main(args: Array[String]): Unit = {
    // 创建spark配置对象
    val conf = new SparkConf().setAppName("spam").setMaster("local[*]")
    // 创建上下文对象
    val context = new SparkContext(conf)
    // 获取垃圾邮件与非垃圾邮件的原数据列表
    val spamContentList =  DoFile.getFileContentList("C:\\Users\\asus\\Desktop\\data\\email\\spam_test")
    val hamContentList =  DoFile.getFileContentList("C:\\Users\\asus\\Desktop\\data\\email\\ham_test")
    // 将源数据切割分词形成邮件分词列表(切割针对每一条邮件形成分词列表)
    val spamCutList = cutListString(spamContentList)
    val hamCutList = cutListString(hamContentList)
    // 通过邮件数据和分词列表获取特征关键词列表
    // 这里关键词为每一封邮件top15,然后再统一top100
    // 所以训练源数据每一封邮件分词不得少于15
    // 总邮件数 * 15 不得少于100(这里的总邮件特指垃圾邮件/非垃圾邮件,两者之间在获取时候没有交集)
    val spamKeyWord = getKeyWords(spamContentList,spamCutList,context)
    val hamKeyWord = getKeyWords(hamContentList,hamCutList,context)
    // 汇总垃圾邮件和非垃圾邮件关键词
    val keyWord = Array.concat(spamKeyWord,hamKeyWord)
    MessageLog.getConsoleLog("Info","特征汇总完毕!正在对特征值进行持久化...")
    // 将关键词本地持久化,留以使用训练模型测试数据时候使用
    val file = new PrintWriter(new File("C:\\Users\\asus\\Desktop\\data\\email\\keys\\spam.txt"))
    for(x <- 0 until keyWord.length){
      file.write(keyWord(x).replace(" ","") + ",")
    }
    file.close()
    // 获取每一封邮件分词进行特征值映射,存在为0,不存在为1,特征值列表总长度为100 + 100 = 200 (垃圾邮件特征抽取TOP + 非垃圾邮件特征抽取TOP)
    // 结果为列表,类似[[0.0,0.0,1.0,0.0,1.0...],[0.0,0.0,0.0,1.0,1.0],[.....]]这样的数据结构
    val spamKeyMapList = mapKeyWord(spamCutList,keyWord)
    val hamKeyMapList = mapKeyWord(hamCutList,keyWord)

    MessageLog.getConsoleLog("Info","特征映射完毕!开始对构建训练数据...")
    // 对特征列表进行格式化对象,与准备Seq容器去装载向量标签对象
    var trainSeq = Seq[LabeledPoint]()
    // 垃圾邮件指定标签为1
    for(x <- 0 until spamKeyMapList.size()) {
      trainSeq = trainSeq :+ LabeledPoint(1.0, Vectors.dense(spamKeyMapList.get(x)))
    }
    // 非垃圾邮件指定标签为2
    for(x <- 0 until hamKeyMapList.size()) {
      trainSeq = trainSeq :+ LabeledPoint(0.0, Vectors.dense(hamKeyMapList.get(x)))
    }
    // 通过构建的容器生成RDD
    val trainRDD = context.parallelize(trainSeq)
    MessageLog.getConsoleLog("Info","构建完成!开始对特征数据进行训练...")
    // 训练数据生成模型
    val model = NaiveBayes.train(trainRDD)
    // 存储训练模型
    val modelDir = new File("C:\\Users\\asus\\Desktop\\data\\email\\model")
    val modelDirFileList = modelDir.list()
    if(modelDirFileList.size == 0){
      MessageLog.getConsoleLog("Info","训练完成!正在持久化模型...")
      model.save(context,"C:\\Users\\asus\\Desktop\\data\\email\\model")
      MessageLog.getConsoleLog("Info","持久化完成!程序运行结束!")
    }
  }

  /**
    * 通过源数据列表与源数据的分词列表获取TOP100的热词列表
    * @param contentList
    * @param cutList
    * @param context
    * @return
    */
  def getKeyWords(contentList:util.ArrayList[String], cutList:util.ArrayList[String],context: SparkContext) = {
    // 空列表用于放置每一封邮件的TOP15热词元组(String,Int)
    val top15ArrList = new util.ArrayList[Array[Tuple2[String,Int]]]()
    MessageLog.getConsoleLog("Info","分词结束!开始进行词频统计...")
    // 以每一封邮件为基本单位对邮件抽取TOP15的热词
    for(i <- 0 until contentList.size()){
      val sourceRDD = context.parallelize[String](Seq(cutList.get(i)))
      val sortArray = sourceRDD.flatMap(_.replace("[","")
        .replace("]","")
        .split(","))
        .map((_,1))
        .reduceByKey(_ + _)
        .map(t => (t._2,t._1))
        .sortByKey(false)
        .map(t => (t._2,t._1))
        .take(15)
      top15ArrList.add(sortArray)
    }

    // 实例化元组容器
    var seq = Seq[Tuple2[String,Int]]()
    MessageLog.getConsoleLog("Info","词频统计结束!开始对所有邮件特征值再次进行汇总聚合...")
    // 挨个取出每一封邮件的TOP15,装载进Seq大容器
    for(x <- 0 until top15ArrList.size()){
      val arr = top15ArrList.get(x)
      for(y <- 0 until arr.size){
        seq = seq :+ (arr(y)._1,arr(y)._2.toInt)
      }
    }
    // 将元组容器转化为RDD
    val allRDD = context.parallelize[Tuple2[String,Int]](seq)
    // 对所有TOP进行降序排序
    val resAllRDD = allRDD.reduceByKey(_ + _).map(
      t =>
        (t._2,t._1)).sortByKey(false)
      .map(t => (t._2,t._1))
    // 获取TOP100数据,并且只取值Key,丢弃Value
    val keyWord = resAllRDD.map(_._1).take(100)
    // 返回TOP100热词列表
    keyWord
  }

  /**
    * 通过源数据列表进行分词,获取分词列表
    * @param contentList
    * @return
    */
  def cutListString(contentList:util.ArrayList[String]) = {
    val cutList = new util.ArrayList[String]()
    // 分词操作
    MessageLog.getConsoleLog("Info","开始进行分词...")
    for(i <- 0 until contentList.size()){
      cutList.add(PaoDingCut.cutString(contentList.get(i)).toString)
    }
    cutList
  }

  /**
    * 通过比对特征TOP200热词对每一封email分词列表进行特征值映射(存在为1.0,不存在为0.0)
    * @param cutList
    * @param keyWord
    * @return
    */
  def mapKeyWord(cutList: util.ArrayList[String],keyWord:Array[String]) = {
    val keyLists = new util.ArrayList[Array[Double]]()
    MessageLog.getConsoleLog("Info","特征值持久化完毕!开始对垃圾邮件特征值数据进行映射...")
    for(x <- 0 until cutList.size()){
      val mapTable = new Array[Double](200)
      val emailWordArr = cutList.get(x).replace("[","").replace("]","").split(",")
      for(y <- 0 until 200){
        if(emailWordArr.contains(keyWord(y))){
          mapTable(y) = 1.0
        }else{
          mapTable(y) = 0.0
        }
      }
      keyLists.add(mapTable)
    }
    keyLists
  }
}

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
Java版

SpamTrainJava.java
以下为代码部分:
package top.it1002.spark.ml;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.classification.NaiveBayes;
import org.apache.spark.mllib.classification.NaiveBayesModel;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.rdd.RDD;
import scala.Tuple2;
import top.it1002.util.DoFile;
import top.it1002.util.PaoDingCut;

import java.io.File;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.List;

/**
 * @Author      王磊
 * @Date        2018/12/16
 * @ClassName   SpamTrainJava
 * @Description Scala版数据特征抽取与模型训练
 **/
public class SpamTrainJava {
    public static void main(String[] args) throws Exception {
        SparkConf conf = new SparkConf().setMaster("local[*]").setAppName("spam");
        JavaSparkContext jsc = new JavaSparkContext(conf);
        // 获取源文件数据列表
        ArrayList<String> spamSourceDataList = DoFile.getFileContentList("C:\\Users\\asus\\Desktop\\data\\email\\spam_test");
        ArrayList<String> hamSourceDataList = DoFile.getFileContentList("C:\\Users\\asus\\Desktop\\data\\email\\ham_test");
        // 通过源数据列表进行分词,获取分词列表
        ArrayList<ArrayList<String>> spamCutList= cutString(spamSourceDataList);
        ArrayList<ArrayList<String>> hamCutList= cutString(hamSourceDataList);
        // 对分词列表进行词频统计获取TOP100数据
        List<String> spamTop100 = getTop100(spamCutList, jsc);
        List<String> hamTop100 = getTop100(hamCutList, jsc);
        // 合并TOP100数据得到特征数据TOP200
        ArrayList<String> allKeyWord = new ArrayList<String>();
        String keyWordStr = "";
        for(String word:spamTop100){
            allKeyWord.add(word);
            keyWordStr += word + ",";
        }
        for(String word:hamTop100){
            allKeyWord.add(word);
            keyWordStr += word + ",";
        }
        // 对TOP200数据进行本地持久化
        PrintWriter pw = new PrintWriter("C:\\Users\\asus\\Desktop\\data\\email\\keys\\spam_java.txt");
        pw.write(keyWordStr);
        pw.close();
        // 对源数据列表进行特征匹配生成训练数据
        ArrayList<LabeledPoint> spamLabeledPointList = getTrainData(spamCutList, allKeyWord, 1.0);
        ArrayList<LabeledPoint> hamLabeledPointList  = getTrainData(hamCutList, allKeyWord, 0.0);
        spamLabeledPointList.addAll(hamLabeledPointList);
        // 将训练数据转化为rdd
        RDD<LabeledPoint> trainRDD = jsc.parallelize(spamLabeledPointList).rdd();
        // 数据训练生成模型
        NaiveBayesModel model = NaiveBayes.train(trainRDD);
        // 持久化模型
        File file = new File("C:\\Users\\asus\\Desktop\\data\\email\\model\\model_java");
        if(file.list().length == 0){
            model.save(jsc.sc(),"C:\\Users\\asus\\Desktop\\data\\email\\model\\model_java");
        }

    }

    /**
     * 将源数据进行分词,返回分词列表
     * @param list
     * @return
     */
    public static ArrayList<ArrayList<String>> cutString(ArrayList<String> list){
        // 准备容器装载分词列表
        ArrayList<ArrayList<String>> res = new ArrayList<ArrayList<String>>();
        for(String s:list){
            res.add(PaoDingCut.cutString(s));
        }
        return res;
    }

    /**
     * 获取任意类别词频排名前100的关键词
     * @param cutList
     * @param jsc
     * @return
     */
    public static List<String> getTop100(ArrayList<ArrayList<String>> cutList, JavaSparkContext jsc){
        // 获取TOP15
        ArrayList<List<Tuple2<String, Integer>>> top15List = new ArrayList<List<Tuple2<String, Integer>>>();
        for(ArrayList<String> s:cutList){
            // 将email分词列表转化为javaRDD
            JavaRDD<String>  emailJRDD = jsc.parallelize(s);
            // 映射为元组
            // 聚合
            // 降序排列
            // 获取top15热词列表
            List<Tuple2<String, Integer>> metaList = emailJRDD.mapToPair(new PairFunction<String, String, Integer>() {
                public Tuple2<String, Integer> call(String s) throws Exception {
                    return new Tuple2<String, Integer>(s, 1);
                }
            }).reduceByKey(new Function2<Integer, Integer, Integer>() {
                public Integer call(Integer v1, Integer v2) throws Exception {
                    return v1 + v2;
                }
            }).mapToPair(new PairFunction<Tuple2<String,Integer>, Integer, String>() {
                @Override
                public Tuple2<Integer, String> call(Tuple2<String, Integer> t) throws Exception {
                    return new Tuple2<Integer, String>(t._2(),t._1());
                }
            }).sortByKey(false).mapToPair(new PairFunction<Tuple2<Integer,String>, String, Integer>() {
                public Tuple2<String, Integer> call(Tuple2<Integer, String> t) throws Exception {
                    return new Tuple2<String, Integer>(t._2(),t._1());
                }
            }).take(15);
            top15List.add(metaList);
        }
        // 汇总top15为top100
        ArrayList<Tuple2<String,Integer>> allTuple = new ArrayList<Tuple2<String, Integer>>();
        for(List<Tuple2<String, Integer>> list1:top15List){
            allTuple.addAll(list1);
        }
        // 将所有元组列表转化为JavaRDD
        JavaRDD<Tuple2<String, Integer>> allJRDD = jsc.parallelize(allTuple);
        // 映射
        // 聚合
        // 降序
        // 获取TOP100热词列表
        List<String> top100List = allJRDD.mapToPair(new PairFunction<Tuple2<String,Integer>, String, Integer>() {
            public Tuple2<String, Integer> call(Tuple2<String, Integer> t3) throws Exception {
                return new Tuple2<String,Integer>(t3._1(),t3._2());
            }
        }).reduceByKey(new Function2<Integer, Integer, Integer>() {
            public Integer call(Integer v1, Integer v2) throws Exception {
                return v1 + v2;
            }
        }).mapToPair(new PairFunction<Tuple2<String,Integer>, Integer, String>() {
            public Tuple2<Integer, String> call(Tuple2<String, Integer> t4) throws Exception {
                return new Tuple2<Integer, String>(t4._2(),t4._1());
            }
        }).sortByKey(false).map(new Function<Tuple2<Integer,String>, String>() {
            public String call(Tuple2<Integer, String> v1) throws Exception {
                return v1._2();
            }
        }).take(100);
        return top100List;
    }

    /**
     * 通过分词列表与关键词列表获取训练映射数据
     * @param cutList
     * @param keyWords
     * @param type
     * @return
     */
    public static ArrayList<LabeledPoint> getTrainData(ArrayList<ArrayList<String>> cutList, ArrayList<String> keyWords, double type){
        // 初始化向量标签容器
        ArrayList<LabeledPoint> featureBox = new ArrayList<LabeledPoint>();
        for(ArrayList<String> emailMeta:cutList){
            // 对比映射转化
            double[] mapTrain = new double[200];
            int mapIndex = 0;
            for(String key:keyWords){
                if(emailMeta.contains(key)){
                    mapTrain[mapIndex] = 1.0;
                }
                mapIndex ++;
            }
            featureBox.add(new LabeledPoint(type, Vectors.dense(mapTrain)));
        }
        return featureBox;
    }

}

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
三、SDK封装调用测试类
SpamCheck.java
以下为代码部分
package top.it1002.spark.ml;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.classification.NaiveBayesModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.rdd.RDD;

import top.it1002.util.DoFile;
import top.it1002.util.PaoDingCut;

import java.io.File;
import java.math.RoundingMode;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * @Author      王磊
 * @Date        2018/12/14
 * @ClassName   SpamCheck
 * @Description 通过已经训练好的模型进行邮件分类SDK
 **/
public class SpamCheck {
    public static void textCheck(String str) {
        // 创建spark配置对象
        SparkConf conf = new SparkConf().setAppName("spam").setMaster("local[*]");
        // 创建JavaSpark上下文对象
        JavaSparkContext jsc = new JavaSparkContext(conf);
        // 获取训练数据时候持久化的热词TOP200字符串数据
        ArrayList<String> keyString = DoFile.getFileContentList("C:\\Users\\asus\\Desktop\\data\\email\\keys");
        // 对TOP进行转化为列表
        List<String> keyList = Arrays.asList(keyString.get(0).split(","));
        // 转化为列表,方便函数的调用
        ArrayList<String> strList = new ArrayList<String>();
        strList.add(str);
        // 通过源数据列表与TOP200列表数据比对生成特征值列表
        ArrayList<double[]> strNumList = getCharacter(strList,keyList);
        Vector testVec = Vectors.dense(strNumList.get(0));
        // 加载持久化的训练模型
        NaiveBayesModel model = NaiveBayesModel.load(jsc.sc(),"C:\\Users\\asus\\Desktop\\data\\email\\model");
        // 通过模型对每一封email特征值列表进行预测
        Double predictNum = model.predict(testVec);
        System.out.println("预测内容:" + str);
        System.out.println("预测值:" + predictNum);
        if(predictNum == 1.0){
            System.out.println("预测结果:该邮件为垃圾邮件!");
        }else {
            System.out.println("预测结果:该邮件为正常邮件!");
        }
    }


    /**
     * 通过对已知邮件列表的文件数据进行测试
     * @param spamPath
     * @param hamPath
     */
    public static void fileCheck(String spamPath,String hamPath) {
        // 创建spark配置对象
        SparkConf conf = new SparkConf().setAppName("spam").setMaster("local[*]");
        // 创建JavaSpark上下文对象
        JavaSparkContext jsc = new JavaSparkContext(conf);
        // 获取源数据列表
        ArrayList<String> spamTestList = DoFile.getFileContentList(spamPath);
        ArrayList<String> hamTestList = DoFile.getFileContentList(hamPath);
        // 获取源数据文件名列表
        ArrayList<String> nameList = new ArrayList<String>();
        String[] file1 = new File(spamPath).list();
        String[] file2 = new File(hamPath).list();
        String[] fileNameList = new String[file1.length + file2.length];
        System.arraycopy(file1, 0, fileNameList, 0, file1.length);
        System.arraycopy(file2, 0, fileNameList, file1.length, file2.length);
        // 获取训练数据时候持久化的热词TOP200字符串数据
        ArrayList<String> keyString = DoFile.getFileContentList("C:\\Users\\asus\\Desktop\\data\\email\\keys");
        // 对TOP进行转化为列表
        List<String> keyList = Arrays.asList(keyString.get(0).split(","));
        // 通过源数据列表与TOP200列表数据比对生成特征值列表
        ArrayList<double[]> spamNumList = getCharacter(spamTestList,keyList);
        ArrayList<double[]> hamNumList = getCharacter(hamTestList,keyList);
        // 创建LabeledPoint列表装载所有的测试数据单元
        List<LabeledPoint> lab = new ArrayList<LabeledPoint>();
        // 垃圾邮件装载
        for(double[] arr:spamNumList){
            lab.add(new LabeledPoint(1.0, Vectors.dense(arr)));
        }
        // 非垃圾邮件装载
        for(double[] arr:hamNumList){
            lab.add(new LabeledPoint(0.0, Vectors.dense(arr)));
        }
        // 将LabeledPoint转化为RDD
        RDD<LabeledPoint> testRDD = jsc.parallelize(lab).rdd();
        // 加载持久化的训练模型
        NaiveBayesModel model = NaiveBayesModel.load(jsc.sc(),"C:\\Users\\asus\\Desktop\\data\\email\\model");
        // 通过模型对每一封email特征值列表进行预测
        int index = 0;
        int wrong = 0;
        for(LabeledPoint labeledPoint:lab){
            Double predictNum = model.predict(labeledPoint.features());
            System.out.println("预测文件名称为:" + fileNameList[index]);
            System.out.println("准确值:" + labeledPoint.label() + ",预测值:" + predictNum);
            String type = "";
            if(labeledPoint.label() == 0.0){
                type = "正常邮件";
            }else type = "垃圾邮件";
            if(labeledPoint.label() == predictNum){
                System.out.println("该邮件为" + type  + ",预测准确!");
            }else{
                System.out.println("该邮件为" + type  + ",预测错误!");
                wrong++;
            }
            index++;
        }
        DecimalFormat df = new DecimalFormat("0.00");
        df.setRoundingMode(RoundingMode.HALF_UP);
        System.out.println("======================");
        System.out.println("预测完毕!一共预测" + lab.size() + "封邮件!\r\n" +
                "预测准确:" + (lab.size() - wrong) + "封邮件!\r\n" +
                "预测错误:" + wrong + "封邮件!\r\n" +
                "预测准确率为:" + df.format((double)(lab.size() - wrong) / lab.size()) + "!\r\n"  +
                "预测错误率为:" + df.format((double)wrong / lab.size()) + "!"
        );
        System.out.println("======================");
    }


    /**
     * 对未知目录下邮件进行预测
     * @param sourcePath
     */
    public static void fileCheck(String sourcePath) {
        // 创建spark配置对象
        SparkConf conf = new SparkConf().setAppName("spam").setMaster("local[*]");
        // 创建JavaSpark上下文对象
        JavaSparkContext jsc = new JavaSparkContext(conf);
        // 获取源数据列表
        ArrayList<String> TestList = DoFile.getFileContentList(sourcePath);
        // 获取源数据文件名列表
        ArrayList<String> nameList = new ArrayList<String>();
        String[] fileNameList = new File(sourcePath).list();
        // 获取训练数据时候持久化的热词TOP200字符串数据
        ArrayList<String> keyString = DoFile.getFileContentList("C:\\Users\\asus\\Desktop\\data\\email\\keys");
        // 对TOP进行转化为列表
        List<String> keyList = Arrays.asList(keyString.get(0).split(","));
        // 通过源数据列表与TOP200列表数据比对生成特征值列表
        ArrayList<double[]> testNumList = getCharacter(TestList,keyList);

        // 加载持久化的训练模型
        NaiveBayesModel model = NaiveBayesModel.load(jsc.sc(),"C:\\Users\\asus\\Desktop\\data\\email\\model");
        // 通过模型对每一封email特征值列表进行预测
        int index = 0;
        for(double[] testNum:testNumList){
            Vector testVec = Vectors.dense(testNum);
            Double predictNum = model.predict(testVec);
            String type = "";
            if(predictNum == 0.0){
                type = "正常邮件";
            }else type = "垃圾邮件";
            System.out.println("预测文件名称为:" + fileNameList[index] + ",该邮件为:" + type + "!\r\n");
            index ++;
        }
    }


    /**
     * 通过源数据列表与TOP200列表数据比对生成特征值列表
     * @param testList
     * @param keyList
     * @return
     */
    public static ArrayList<double[]> getCharacter(ArrayList<String> testList,List<String> keyList){
        // 实例化装载容器
        ArrayList<double[]> characterNumList = new ArrayList<double[]>();
        // 对每一条源数据进行分词比对,生成特征值数组,然后装载进列表容器
        for(int i = 0 ; i < testList.size() ; i++){
            double[] keyNum = new double[200];
            ArrayList<String> cutList = PaoDingCut.cutString(testList.get(i));
            for(String s:cutList){
                if(keyList.contains(s)){
                    int index = keyList.indexOf(s);
                    keyNum[index] = 1.0;
                }
            }
            characterNumList.add(keyNum);
        }
        return characterNumList;
    }
}

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
四、app调用类
SpamApp.java
以下为代码部分
package top.it1002.spark.ml;

/**
 * @Author      王磊
 * @Date        2018/12/15
 * @ClassName   SpamApp
 * @Description 邮件分类APP
 **/
public class SpamApp {
    public static void main(String[] args) {
        // 文本测试(测试1、2不可同时开,一个JVM只能开一个SparkContext)
        /*
        // 文本测试1
        String s1 = "贵公司负责人(经理/财务)您好: 深圳市海华公司受多家公司委托向外低点代开部分增值税电脑发票(7%左右)和普通商品销售税 发票。(国税、地税运输、广告、服务等票2%左右)还可以根据所做数量额度的大小来商讨优惠的点数! 本公司郑重承诺所用绝对是真票!可验证后付款! 此信息长期有效,如须进一步洽商: 请电:13480872676 联系人:郑亦文 顺祝商祺! 低点代开发票! ";
        SpamCheck.textCheck(s1);
        //*/
        /*
        // 文本测试2
        String s2 = "现在的孩子懂事都早 我小外甥才两岁多的时候 我姐和我姐夫吵架,说我姐夫不干家务 我小外甥都知道说:妈妈,等我长高了我帮你洗衣服帮你做饭 在火车上我姐拜托旁边的人下车的时候帮她们拿下东西 等到下车的时候小家伙就会拉着人家说:叔叔,你一会儿要帮我妈妈拿东西哦 不知道的还以为是大人教过了 所以教育要趁早。。。 其实5岁的孩子已经啥都懂了 能观颜查色 偶lg姐姐家的孩子也是 ";
        SpamCheck.textCheck(s2);
        //*/

        // 已知类型文件预测模式
        /*
        SpamCheck.fileCheck("C:\\Users\\asus\\Desktop\\data\\email\\spams","C:\\Users\\asus\\Desktop\\data\\email\\hams");
        //*/

        // 未知类型预测模式
        /*
        SpamCheck.fileCheck("C:\\Users\\asus\\Desktop\\data\\email\\fileCheck");
        //*/
    }
}

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
五、工具类
DoFile.java
以下为代码部分
package top.it1002.util;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.util.ArrayList;

/**
 * @Author      王磊
 * @Date        2018/12/13
 * @ClassName   DoFile
 * @Description 文件工具类
 **/
public class DoFile {
    public static ArrayList<String> getFileContentList(String path) {
        File file = new File(path);
        File[] files = file.listFiles();
        ArrayList<String> resList = new ArrayList<String>();
        for(File f:files){
            try {
                FileInputStream fis = new FileInputStream(f.getPath());
                String s = "";
                byte[] buff = new byte[1024];
                int len = 0;
                while((len = fis.read(buff)) != -1){
                    s += new String(buff, 0, len);
                }
                resList.add(s);
            } catch (Exception e) {
                e.printStackTrace();
            }

        }
        return resList;

    }
}

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
MessageLog.java
以下为代码部分
package top.it1002.util;

import java.util.Date;

/**
 * @Author      王磊
 * @Date        2018/12/14
 * @ClassName   MessageLog
 * @Description 控制台提示信息打印类
 **/
public class MessageLog {
    public static void getConsoleLog(String type,String msg) {
        System.out.println(type + " " +getNowTime() + " " + msg);
    }

    public static String getNowTime() {
        long ts = System.currentTimeMillis();
        Date date = new Date(ts);
        return date.toString();
    }
}

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
PaoDingCut.java
以下为代码部分
package top.it1002.util;

import net.paoding.analysis.analyzer.PaodingAnalyzer;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.Token;
import org.apache.lucene.analysis.TokenStream;

import java.io.IOException;
import java.io.StringReader;
import java.util.ArrayList;

/**
 * @Author      王磊
 * @Date        2018/12/13
 * @ClassName   PaoDingCut
 * @Description 庖丁分词类
 **/
public class PaoDingCut {
    public static ArrayList<String> cutString(String text){
        // 定义一个解析器
        Analyzer analyzer = new PaodingAnalyzer();

        // 得到token序列的输出流
        TokenStream tokens = analyzer.tokenStream(text, new StringReader(text));
        // 定义返回结果列表
        ArrayList<String> tokenList = new ArrayList<String>();
        try{
            Token t;
            while((t=tokens.next() ) !=null){
                // 限制长度为1的不处理
                if(t.termText().length() > 1){
                    tokenList.add(t.termText());
                }
            }
        }catch(IOException e){
            e.printStackTrace();
        }
        return tokenList;
    }
}

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
大致思路就是这样了
后续会更新java/scala版本相关小例子


--------------------- 
作者:王磊呀 
来源:CSDN 
原文:https://blog.csdn.net/qq_41287993/article/details/85013378 
版权声明:本文为博主原创文章,转载请附上博文链接!

猜你喜欢

转载自blog.csdn.net/weixin_44226873/article/details/85534189