简单介绍:
这里我举例来简单的说明下贝叶斯算法 :
如上图 : 假设一个班有 100 人 , 其中80%的是玩王者荣耀的 , 20%玩吃鸡 ,那么在所有人中 玩LOL的占 10%
其中 同时玩 王者 和 LOL 的人 有 8个人 ,占所有玩 王者的人中 的 8 / 80 ,同样的 同时玩 吃鸡和 LOL的占所有玩吃鸡人数的
2 / 20 根据这个概率表 , 我们可以计算出 当我们只知道玩 LOL的前提下 , 这个人可能玩 王者 , 或者 吃鸡的概率
如上图 : 其实简单的理解就是 用 同时 玩王者和LOL的人数 除以 所有玩 LOL的总人数 , 这样就可以计算出 ,当一个人玩LOL 的前提下 , 这个人还有可能玩王者的概率 , 同理可以计算出 当一个人玩LOL的前提下 还有可能玩 吃鸡的概率。
然后二者进行比较 , 那个概率大就说明更有可能玩哪种游戏 。 如上图,显然当知道一个人玩LOL的前提下 , 这个人玩王者的概率 显然要大于玩 吃鸡的概率 , 那么我们更倾向于 这个人可能玩 王者的概率较大。
这里我们可以的得出一个贝叶斯公式 :P(A|B) = (PB|A)* P(A) / P(B) 如下图
这里我在右边贴出了 计算 玩LOL的前提下 玩王者的概率 计算公式 , 进行对比 比较容易理解 。其实这个就是贝叶斯算法的简单理解 , 下面说说垃圾邮件预测的思路
垃圾邮件预测介绍
算法思路
这里我同样用举例的方式来介绍 ,那么假设 我们有 100封邮件 , 其中80封是正常邮件(占所有有邮件的80%)20封是正常邮件(占所有邮件的20%) , 然后 ,在所有邮件中 有 5 封是含有 fuck 这个单词的 (占所有邮件的5%)那么我们可以得出下面这个概率表
如上图 : 同理我们可以计算出 , 当一封邮件中出现 fuck这个单词的时候 , 这封邮件可能是垃圾邮件的概率是 80% , 有可能是非垃圾邮件的概率是 20% , 通过比较二者的概率 , 可以计算出 ,当一封邮件出现 fuck 单词时 , 这封邮件更有可能是垃圾单词的,当然 , 一封邮件中肯定不可能只有一个单词, 我们可以根据出现的单词在不同的邮件中出现的概率计算出当这些邮件中出现这些单词时 ,这封邮件可能是垃圾邮件的概率。
至于 如何计算每个单词在不同种类的概率 ,下面我举例说明
如上图 , 这里有四封邮件 , 我们可以对邮件内容进行切分, 得到切分后的所有不重复的单词 , 然后统计邮件中对应单词的数量,进而计算出单词在指定类型的邮件中出现的概率。 比如 单词 i ,在四封邮件中出现了三次, 那么我们可以认为 i 在这种类型的单词中出现的概率是 3 / 4 进而进行计算得出 一个单词概率表 ,用于计算。
数据来源
垃圾邮件预测的最重要的一环是 构建模型的数据来源 , 因为数据源直接影响了你最终预测的准确性, 我在网上找到的多是英文的邮件 , 中文的比较少 , 这个是我好不容易找到的 ,大概有14000+ 封邮件 , 其中一半是正常的邮件 , 一半是垃圾邮件 ,这里我直接给出百度云链接 , 如果大家有好的数据源 也可已在评论区分享下 , 感谢
百度网盘下载链接: https://pan.baidu.com/s/1Hsno4oREMROxWwcC_jYAOA
密码: qa49
这里有一点需要注意的是 ,这些邮件都是GBK格式的 ,如果你的项目使用的是UTF-8 ,那么有可能会乱码 , 这里我是使用的是转换流, 使用GBK格式读取 , 使用UTF-8 写出 , 然后就可以批量转化 为 UTF-8 格式 代码如下
package com.wangt.bayes.test;
import spark.utils.IOUtils;
import java.io.*;
/**
* 批量转换文件编码
*/
public class ConductData {
static InputStreamReader isr = null;
static OutputStreamWriter osr = null;
public static void main(String[] args) throws IOException {
File normal = new File("data/test/normal"); // 这里读取的是文件的目录
File[] normals = normal.listFiles();
ConductData cd = new ConductData();
for (int i = 0; i < normals.length; i++) {
cd.changeEcoding(normals[i]);
System.out.println("更改成功 +" + i);
}
}
/**
* 转换文件的编码
* @param f
*/
public void changeEcoding(File f){
try {
isr = new InputStreamReader(new FileInputStream(f), "GBK");
File out = new File("mydata/test/ham/" + f.getName() + ".txt");
if (!out.exists()){
boolean res = out.createNewFile();
System.out.println("文件 : " +f.getName()+ (res ? "创建成功" : "创建失败"));
}
osr = new OutputStreamWriter(new FileOutputStream(out) , "UTF-8");
IOUtils.copy(isr , osr); // 这个是saprk自带的拷贝流
} catch (UnsupportedEncodingException e) {
e.printStackTrace();
} catch (FileNotFoundException e) {
e.printStackTrace();
} catch (IOException e) {
e.printStackTrace();
}
// 关闭流
if(isr != null){
try {
isr.close();
} catch (IOException e) {
e.printStackTrace();
}
}
if(osr != null){
try {
osr.close();
} catch (IOException e) {
e.printStackTrace();
}
}
}
}
构建模型的流程
文件切分
这里我使用的是结巴分词器 ,它可以支持中文分词 ,代码如下
/**
* 对单个字符串进行切分 获取单词集合
*
* @param line 被切分的单词
* @return 返回存储被切分的单词的 集合
*/
public static List<String> splitWord(String line) {
// 创建结巴分词器
JiebaSegmenter segmenter = new JiebaSegmenter();
// 创建存储 分割后的单词的结婚
ArrayList<String> datas = new ArrayList<>();
// 分词
List<SegToken> list = segmenter.process(line, JiebaSegmenter.SegMode.SEARCH);
// 过滤掉空格
for (SegToken segToken : list) {
if (segToken.equals(" ")) {
continue;
}
datas.add(segToken.word);
}
return datas;
}
如果你不知道如何下载结巴分词器 , 下面我提供了maven的依赖
<!-- 结巴 分词-->
<!-- https://mvnrepository.com/artifact/com.huaban/jieba-analysis -->
<dependency>
<groupId>com.huaban</groupId>
<artifactId>jieba-analysis</artifactId>
<version>1.0.2</version>
</dependency>
构建词袋
下一步是获取所有不重复的单词 ,构建词袋 ,这里我使用的是 TreeMap , 因为Map的key 不能重复 ,正好可以存储单词 , value是默认为 0 , 为下一步计算单词词频做铺垫
/**
* 切分一个字符串 , 获取不重复的单词 , 并存储到一个 Map 中
* key 为 单词 , value 默认是 0
*
* @param lines 获取重复单词的数据
* @return 存放不重复单词的词袋
*/
public static TreeMap<String, Integer> getWordBag(String lines) {
// 分词
List<String> values = splitWord(lines);
// 获取不重复的单词
TreeMap<String, Integer> wordBag = new TreeMap<>();
for (String line : values) {
wordBag.put(line, 0);
}
// 返回值
return wordBag;
}
统计单词词频 , 并将词频转化为 double数组
/**
* 指定一个存储单词的词袋 , 切分指定字符串后得到单词 , 获取词袋中对应单词出现的次数组成double数组
*
* @param line 需要统计单词的语句
* @param wordBag 存放不重复单词的词袋
* @return 存放单词词频的double数组
*/
public static double[] getWordCount(String line, TreeMap<String, Integer> wordBag) {
// 创建存放单词词频的 对象
TreeMap<String, Integer> countWord = new TreeMap<>();
// 将词袋内的单词添加进 统计词频的对象中
countWord.putAll(wordBag);
// 对被统计的单词进行分词
List<String> wordDatas = splitWord(line);
// 统计单词词频
for (String wordData : wordDatas) {
// 如果被统计的单词在词袋中没有出现 , 将该单词丢弃
if (!countWord.containsKey(wordData)) {
continue;
}
countWord.replace(wordData, countWord.get(wordData) + 1);
}
// 将单词词频存放到 double 数组中
double[] values = new double[countWord.size()];
int size = 0;
Iterator<String> countKey = countWord.keySet().iterator();
while (countKey.hasNext()) {
String word = countKey.next();
values[size] = countWord.get(word);
size++;
}
// 返回存放单词词频的double数组
return values;
}
然后是 将 词频数组封装到 LabeledPoint 对象中 进行模型的构建
public static void main(String[] args) {
// 1.获取 SparkContext
JavaSparkContext sc = BayesUtils.init();
// 2.读取样本数据
JavaRDD<String> lines = sc.textFile("input/navie_bayes_data.txt");
List<String> words = lines.take((int) lines.count());
// 3.使用结巴分词器进行分词 获取不重复的单词
TreeMap<String , Integer> wordBag = BayesUtils.getWordBag(words);
// 4.对数据进行切分 将切分后的 (标签 , 内容) 封装到 LabeledPoint 对象中 ,并将所有的LabeledPoint存储到 RDD中
JavaRDD<LabeledPoint> parsedData = lines.map(new Function<String, LabeledPoint>() {
@Override
public LabeledPoint call(String v1) throws Exception {
// 对数据进行按行切分
String[] fields = v1.split(",");
double key = 0.0;
if(fields[0].equals("ham")){ // ham = 0 , spam = 1 这里特别注意 是fields[0] 不是fields
key = 0;
}else {
key = 1;
}
// 获取单词词频
double[] datas = BayesUtils.getWordCount(fields[1] , wordBag);
// 返回封装好的值
return new LabeledPoint(key ,new DenseVector(datas));
}
});
// 6.样本数据划分 训练样本 和测试样本
// 一个参数是 训练样本和测试的比例 , 另一个是随机数的种子
JavaRDD<LabeledPoint>[] splits = parsedData.randomSplit(new double[]{0.9 , 0.1} , 100L);
// 7.获取样本集
// 获取训练样本集
JavaRDD<LabeledPoint> training = splits[0];
// 获取测试样本集
JavaRDD<LabeledPoint> test = splits[1];
// 8.对数据进行训练 , 新建贝叶斯分类模型
// 需要将 JavaRDD 转化为 org.apache.spark.rdd.RDD
// 设置训练集 , 以及拉普拉斯估计值
NaiveBayesModel model = NaiveBayes.train(training.rdd() , 1);
// 9.对测试样本进行测试 并获取预测的值
JavaPairRDD<Double , Double> predictResult = test.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
@Override
public Tuple2<Double, Double> call(LabeledPoint labeledPoint) throws Exception {
// 获取预测的值
Double res = model.predict(labeledPoint.features());
// 返回值
return new Tuple2<>(res , labeledPoint.label());
}
});
predictResult.foreach(new VoidFunction<Tuple2<Double, Double>>() {
@Override
public void call(Tuple2<Double, Double> value) throws Exception {
System.out.println("预测的结果 : " + value._1 + "\t" + "实际的结果 : " + value._2);
}
});
predictResult.take(100);
// 统计预测成功的个数
Long sucess = predictResult.filter(new Function<Tuple2<Double, Double>, Boolean>() {
@Override
public Boolean call(Tuple2<Double, Double> v1) throws Exception {
return v1._1().equals(v1._2()); // 注意 这里是 Double类型的 不能直接用 ==
}
}).count();
System.out.println("sucess = " + sucess + "\t test = " + test.count());
double rate = (((double)sucess) / test.count() ) * 100;
System.out.println("预测的成功率是 :" + String.format("%.2f" , rate)+"%");
// 保存模型
//model.save(sc.sc() , "model/my_navie_bayes_model");
// 关闭 SparkContext
sc.stop();
}
然后是预测的代码 :
/**
* 预测邮件是否为垃圾邮件
*
* @param predictData 要预测的邮件
* @return 返回预测结果 如果是 0 代表是正常邮件 如果是 1 代表邮件
*/
public double predictChinese(String predictData) {
// 获取 SparkContext 对象
JavaSparkContext sc = BayesUtils.init();
// 获取模型
String modelPath = Predicted.class.getClassLoader().getResource("model_cn/naive_model_optimize").getPath();
NaiveBayesModel model = NaiveBayesModel.load(sc.sc() , modelPath);
// 获取词袋
// 此处加载 resources 的资源时 一定要 获取类加载器
String wordBagPath = Predicted.class.getClassLoader().getResource("wordbags_cn").getPath();
JavaRDD<String> wordBagRDD = sc.textFile(wordBagPath);
TreeMap<String, Integer> wordBags = BayesUtils.getWordBag(wordBagRDD.collect());
System.out.println("wordBags : ==>" + wordBags.size());
// 获取停用单词
String stopwWordPath = Predicted.class.getClassLoader().getResource("stopWords.txt").getPath();
double[] datas = BayesUtils.getWordCount(predictData, wordBags);
// 开始预测
// 将语句向量封装到 LabeledPoint 对象中
LabeledPoint lp = new LabeledPoint(-1, new DenseVector(datas));
// 获取预测结果
double predictResult = model.predict(lp.features());
// 关闭 SparkContext
sc.stop();
return predictResult;
}
总结
maven 依赖
<dependencies>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>4.11</version>
<scope>test</scope>
</dependency>
<!-- spark 核心依赖包 -->
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_2.11</artifactId>
<version>2.1.1</version>
</dependency>
<!-- spark mllib 依赖-->
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib_2.11</artifactId>
<version>2.1.1</version>
</dependency>
<!-- spark javaAPI 依赖-->
<dependency>
<groupId>com.sparkjava</groupId>
<artifactId>spark-core</artifactId>
<version>2.1</version>
</dependency>
<!-- spark sql 依赖-->
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_2.11</artifactId>
<version>2.1.2</version>
</dependency>
<!-- jcesg 分词 -->
<dependency>
<groupId>org.lionsoul</groupId>
<artifactId>jcseg-core</artifactId>
<version>2.4.0</version>
</dependency>
<dependency>
<groupId>org.lionsoul</groupId>
<artifactId>jcseg-analyzer</artifactId>
<version>2.4.0</version>
</dependency>
<dependency>
<groupId>org.lionsoul</groupId>
<artifactId>jcseg-elasticsearch</artifactId>
<version>2.4.0</version>
</dependency>
<!-- 结巴 分词-->
<!-- https://mvnrepository.com/artifact/com.huaban/jieba-analysis -->
<dependency>
<groupId>com.huaban</groupId>
<artifactId>jieba-analysis</artifactId>
<version>1.0.2</version>
</dependency>
<!-- scala 库-->
<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-library</artifactId>
<version>2.11.8</version>
</dependency>
<!-- https://mvnrepository.com/artifact/org.scala-lang/scala-compiler -->
<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-compiler</artifactId>
<version>2.11.8</version>
</dependency>
<!-- https://mvnrepository.com/artifact/org.scala-lang/scala-reflect -->
<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-reflect</artifactId>
<version>2.11.8</version>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-api</artifactId>
<version>RELEASE</version>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-api</artifactId>
<version>RELEASE</version>
<scope>compile</scope>
</dependency>
</dependencies>
构建模型的代码 :
package com.wangt.bayes.test;
import com.wangt.bayes.utils.BayesUtils;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.mllib.classification.NaiveBayes;
import org.apache.spark.mllib.classification.NaiveBayesModel;
import org.apache.spark.mllib.linalg.DenseVector;
import org.apache.spark.mllib.regression.LabeledPoint;
import scala.Tuple2;
import java.util.List;
import java.util.TreeMap;
/**
* @author wangt
* @create 2019-03-06 20:35
*/
public class TrainModel {
/* public static NaiveBayesModel trainModel(String trainDataPath){
}
*/
public static void main(String[] args) {
// 1.获取 SparkContext
JavaSparkContext sc = BayesUtils.init();
// 2.读取样本数据
JavaRDD<String> lines = sc.textFile("input/navie_bayes_data.txt");
List<String> words = lines.take((int) lines.count());
// 3.使用结巴分词器进行分词 获取不重复的单词
TreeMap<String , Integer> wordBag = BayesUtils.getWordBag(words);
// 4.对数据进行切分 将切分后的 (标签 , 内容) 封装到 LabeledPoint 对象中 ,并将所有的LabeledPoint存储到 RDD中
JavaRDD<LabeledPoint> parsedData = lines.map(new Function<String, LabeledPoint>() {
@Override
public LabeledPoint call(String v1) throws Exception {
// 对数据进行按行切分
String[] fields = v1.split(",");
double key = 0.0;
if(fields[0].equals("ham")){ // ham = 0 , spam = 1 这里特别注意 是fields[0] 不是fields
key = 0;
}else {
key = 1;
}
// 获取单词词频
double[] datas = BayesUtils.getWordCount(fields[1] , wordBag);
// 返回封装好的值
return new LabeledPoint(key ,new DenseVector(datas));
}
});
// 6.样本数据划分 训练样本 和测试样本
// 一个参数是 训练样本和测试的比例 , 另一个是随机数的种子
JavaRDD<LabeledPoint>[] splits = parsedData.randomSplit(new double[]{0.9 , 0.1} , 100L);
// 7.获取样本集
// 获取训练样本集
JavaRDD<LabeledPoint> training = splits[0];
// 获取测试样本集
JavaRDD<LabeledPoint> test = splits[1];
// 8.对数据进行训练 , 新建贝叶斯分类模型
// 需要将 JavaRDD 转化为 org.apache.spark.rdd.RDD
// 设置训练集 , 以及拉普拉斯估计值
NaiveBayesModel model = NaiveBayes.train(training.rdd() , 1);
// 9.对测试样本进行测试 并获取预测的值
JavaPairRDD<Double , Double> predictResult = test.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
@Override
public Tuple2<Double, Double> call(LabeledPoint labeledPoint) throws Exception {
// 获取预测的值
Double res = model.predict(labeledPoint.features());
// 返回值
return new Tuple2<>(res , labeledPoint.label());
}
});
predictResult.foreach(new VoidFunction<Tuple2<Double, Double>>() {
@Override
public void call(Tuple2<Double, Double> value) throws Exception {
System.out.println("预测的结果 : " + value._1 + "\t" + "实际的结果 : " + value._2);
}
});
predictResult.take(100);
// 统计预测成功的个数
Long sucess = predictResult.filter(new Function<Tuple2<Double, Double>, Boolean>() {
@Override
public Boolean call(Tuple2<Double, Double> v1) throws Exception {
return v1._1().equals(v1._2()); // 注意 这里是 Double类型的 不能直接用 ==
}
}).count();
System.out.println("sucess = " + sucess + "\t test = " + test.count());
double rate = (((double)sucess) / test.count() ) * 100;
System.out.println("预测的成功率是 :" + String.format("%.2f" , rate)+"%");
// 保存模型
//model.save(sc.sc() , "model/my_navie_bayes_model");
// 关闭 SparkContext
sc.stop();
}
}
封装好的工具方法
package com.wangt.bayes.utils;
import com.huaban.analysis.jieba.JiebaSegmenter;
import com.huaban.analysis.jieba.SegToken;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import java.io.File;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.TreeMap;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
* @author wangt
* @create 2019-03-02 16:43
*/
public class BayesUtils {
/**
* 初始化 获取 SparkContext 对象
*
* @return
*/
public static JavaSparkContext init() {
// 1.创建 SparkConf 对象
SparkConf conf = new SparkConf();
// 2.设置 Spark 运行模式 以及 应用名称
conf.setMaster("local").setAppName("MyNaive_bayes");
// 3.获取 SparkContext 对象
/**
* SparkContext 是 Spark 程序的入口 相当于 Hadoop的Job
*/
return new JavaSparkContext(conf);
}
/**
* 对单个字符串进行切分 获取单词集合
*
* @param line 被切分的单词
* @return 返回存储被切分的单词的 集合
*/
public static List<String> splitWord(String line) {
// 创建结巴分词器
JiebaSegmenter segmenter = new JiebaSegmenter();
// 创建存储 分割后的单词的结婚
ArrayList<String> datas = new ArrayList<>();
// 分词
List<SegToken> list = segmenter.process(line, JiebaSegmenter.SegMode.SEARCH);
// 过滤掉空格
for (SegToken segToken : list) {
if (segToken.equals(" ")) {
continue;
}
datas.add(segToken.word);
}
return datas;
}
/**
* 对单个中文字符串进行切分 获取单词集合
*
* @param line 被切分的单词
* @return 返回存储被切分的单词的 集合
*/
public static List<String> splitChineseWord(String line) {
// 创建结巴分词器
JiebaSegmenter segmenter = new JiebaSegmenter();
// 创建存储 分割后的单词的结婚
ArrayList<String> datas = new ArrayList<>();
// 分词
List<SegToken> list = segmenter.process(line, JiebaSegmenter.SegMode.SEARCH);
// 过滤掉 无意义的单词
for (SegToken segToken : list) {
if (isChinese(segToken.word)) {
datas.add(segToken.word);
}
}
return datas;
}
/**
* 过滤掉停用单词
*
* @param words 需要过滤的单词
* @param stopWordsBags 停用单词库
* @return
*/
public static List<String> filterStopWords(List<String> words, List<String> stopWordsBags) {
// 存储 过滤后单词
ArrayList<String> lines = new ArrayList<>();
for (String word : words) {
if (!stopWordsBags.contains(word)) {
lines.add(word);
}
}
return lines;
}
/**
* 指定一个存储单词的词袋 , 切分指定字符串后得到单词 , 获取词袋中对应单词出现的次数组成double数组
*
* @param line 需要统计单词的语句
* @param wordBag 存放不重复单词的词袋
* @return 存放单词词频的double数组
*/
public static double[] getWordCount(String line, TreeMap<String, Integer> wordBag) {
// 创建存放单词词频的 对象
TreeMap<String, Integer> countWord = new TreeMap<>();
// 将词袋内的单词添加进 统计词频的对象中
countWord.putAll(wordBag);
// 对被统计的单词进行分词
List<String> wordDatas = splitWord(line);
// 统计单词词频
for (String wordData : wordDatas) {
// 如果被统计的单词在词袋中没有出现 , 将该单词丢弃
if (!countWord.containsKey(wordData)) {
continue;
}
countWord.replace(wordData, countWord.get(wordData) + 1);
}
// 将单词词频存放到 double 数组中
double[] values = new double[countWord.size()];
int size = 0;
Iterator<String> countKey = countWord.keySet().iterator();
while (countKey.hasNext()) {
String word = countKey.next();
values[size] = countWord.get(word);
size++;
}
// 返回存放单词词频的double数组
return values;
}
/**
* 统计 一个list 内的单词在对应词袋中出现的频率
*
* @param lines 存储一条语句切分后的单词
* @param wordBag 存放不重复单词的词袋
* @return 存放单词词频的double数组
*/
public static double[] getWordCount(List<String> lines, TreeMap<String, Integer> wordBag) {
// 创建存放单词词频的 对象
TreeMap<String, Integer> countWord = new TreeMap<>();
// 将词袋内的单词添加进 统计词频的对象中
countWord.putAll(wordBag);
for (String line : lines) {
// 统计单词词频
// 如果被统计的单词在词袋中没有出现 , 将该单词丢弃
if (countWord.containsKey(line)) {
countWord.replace(line, countWord.get(line) + 1);
}
}
// 将单词词频存放到 double 数组中
double[] values = new double[countWord.size()];
int size = 0;
Iterator<String> countKey = countWord.keySet().iterator();
while (countKey.hasNext()) {
String word = countKey.next();
values[size] = countWord.get(word);
size++;
}
// 返回存放单词词频的double数组
return values;
}
/**
* 切分一个字符串 , 获取不重复的单词 , 并存储到一个 Map 中
* key 为 单词 , value 默认是 0
*
* @param lines 获取重复单词的数据
* @return 存放不重复单词的词袋
*/
public static TreeMap<String, Integer> getWordBag(String lines) {
// 分词
List<String> values = splitWord(lines);
// 获取不重复的单词
TreeMap<String, Integer> wordBag = new TreeMap<>();
for (String line : values) {
wordBag.put(line, 0);
}
// 返回值
return wordBag;
}
/**
* 从一个存储字符串的 list集合中读取字符串 , 获取不重复的单词 存储到 map 中
* key 为 不重复的单词 , value 默认为 0
*
* @param lines 源数据
* @return 存放不重复单词的词袋
*/
public static TreeMap<String, Integer> getWordBag(List<String> lines) {
// 分词
List<String> values = new ArrayList<String>();
for (String line : lines) {
// 同时兼容可以切分 带有 ham,内容 的 和不带 ham的
String[] s = line.split(",");
String words = s.length > 1 ? s[1] : s[0];
values.addAll(splitWord(words));
}
// 过滤重复单词
TreeMap<String, Integer> wordBag = new TreeMap<>();
for (String value : values) {
wordBag.put(value, 0);
}
// 返回值
return wordBag;
}
/**
* 将从词袋文件中读取的单词转化为 更易使用的 Map
*
* @param wordBags
* @return
*/
public static TreeMap<String, Integer> getWordBagToMap(List<String> wordBags) {
TreeMap<String, Integer> wordBag = new TreeMap<>();
for (String value : wordBags) {
wordBag.put(value, 0);
}
// 返回值
return wordBag;
}
/**
* 指定一个文件的路径 , 读取文件中的单词 并且获取不重复的单词存储到一个 map中
* 其中 key 是 单词 , value 默认为 0
*
* @param sc SparkContext对象
* @param path 提取不重复单词的路径
* @return 词袋
*/
public static TreeMap<String, Integer> getWordBag(JavaSparkContext sc, String path) {
// 文件不存在时抛出异常
if (!new File(path).exists()) {
new RuntimeException("file is not exist");
}
// 读取文件
JavaRDD<String> lines = sc.textFile(path);
// 将 RDD 转化为 list
List<String> datas = lines.take((int) lines.count());
// 获取词袋 并发挥
return getWordBag(datas);
}
/**
* 判断 一个字符串是否包含中文
*
* @param word 需要判断的单词
* @return 如果是中文 则返回 true , 否则返回 false
*/
public static boolean isChinese(String word) {
// 正则匹配
// String parm="[\\u4e00-\\u9fa5]+"; // 表示一个或者多个中文
String parm = ".*[\\u4e00-\\u9faf].*"; // 表示一个或者多个中文
// 编译
Pattern pattern = Pattern.compile(parm);
Matcher m = pattern.matcher(word);
return m.matches();
}
/**
* 过滤一个字符串中不是中文的部分 然后将所有中文词存储到一个集合中返回
*
* @param lines 被过滤的集合
* @return 存储过滤后的数据的集合
*/
public static List<String> filterNotChineseWords(List<String> lines) {
// 存储过滤后的数据
ArrayList<String> newLines = new ArrayList<>();
// 对数据过滤
for (String line : lines) {
// 判断是否是中文
if (isChinese(line)) {
newLines.add(line);
}
}
return newLines;
}
/**
* 去除掉字符串中不是中文的部分 , 然后返回剩下是中文的字符串
*
* @param line 被过滤的字符串
* @return 存储过滤后的字符串
*/
public static String filterNotChineseWords(String line) {
// 对字符串进行切分
List<String> lines = BayesUtils.splitWord(line);
// 存储切分后的字符串
StringBuilder words = new StringBuilder();
// 对数据过滤
for (String s : lines) {
// 判断是否是中文
if (isChinese(s)) {
words.append(s); // 认真写代码
}
}
return words.toString();
}
public static void main(String[] args) {
List<String> stopwords = new ArrayList<>();
stopwords.add("java");
stopwords.add("c");
List<String> words = new ArrayList<>();
words.add("hello");
words.add("java");
words.add("c");
words.add("jk");
words = filterStopWords(words, stopwords);
for (String word : words) {
System.out.println(word);
}
}
}
预测的方法
/**
* 预测邮件是否为垃圾邮件
*
* @param predictData 要预测的邮件
* @return 返回预测结果 如果是 0 代表是正常邮件 如果是 1 代表邮件
*/
public double predictChinese(String predictData) {
// 获取 SparkContext 对象
JavaSparkContext sc = BayesUtils.init();
// 获取模型
String modelPath = Predicted.class.getClassLoader().getResource("model_cn/naive_model_optimize").getPath();
NaiveBayesModel model = NaiveBayesModel.load(sc.sc() , modelPath);
// 获取词袋
// 此处加载 resources 的资源时 一定要 获取类加载器
String wordBagPath = Predicted.class.getClassLoader().getResource("wordbags_cn").getPath();
JavaRDD<String> wordBagRDD = sc.textFile(wordBagPath);
TreeMap<String, Integer> wordBags = BayesUtils.getWordBag(wordBagRDD.collect());
System.out.println("wordBags : ==>" + wordBags.size());
// 获取停用单词
String stopwWordPath = Predicted.class.getClassLoader().getResource("stopWords.txt").getPath();
double[] datas = BayesUtils.getWordCount(predictData, wordBags);
// 开始预测
// 将语句向量封装到 LabeledPoint 对象中
LabeledPoint lp = new LabeledPoint(-1, new DenseVector(datas));
// 获取预测结果
double predictResult = model.predict(lp.features());
// 关闭 SparkContext
sc.stop();
return predictResult;
}