Maven pom 依赖（Spark 2.4.0 的 Scala 2.12 构件：spark-core / spark-streaming / spark-mllib，另加 paranamer 2.8）
<dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-core_2.12</artifactId> <version>2.4.0</version> </dependency> <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-streaming_2.12</artifactId> <version>2.4.0</version> </dependency> <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-mllib_2.12</artifactId> <version>2.4.0</version> </dependency> <dependency> <groupId>com.thoughtworks.paranamer</groupId> <artifactId>paranamer</artifactId> <version>2.8</version> </dependency>
示例代码（Regression.java，逻辑回归二分类示例）
import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.classification.LogisticRegressionModel; import org.apache.spark.mllib.classification.LogisticRegressionWithSGD; import org.apache.spark.mllib.feature.HashingTF; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.regression.LabeledPoint; import java.util.Arrays; import java.util.Iterator; /** * 逻辑回归算法 */ public class Regression { public static void main(String[] args) { SparkConf conf = new SparkConf().setAppName("asd").setMaster("local[2]"); JavaSparkContext sc = new JavaSparkContext(conf); sc.setLogLevel("ERROR"); JavaRDD<String> data1 = sc.textFile("C:\\workspace\\sparkTest\\src\\test\\java\\data\\1.txt"); JavaRDD<String> data2 = sc.textFile("C:\\workspace\\sparkTest\\src\\test\\java\\data\\2.txt"); JavaRDD<String> javaRDD1 = data1.flatMap(new FlatMapFunction<String, String>() { @Override public Iterator<String> call(String s) throws Exception { String[] split = s.split(" "); return Arrays.asList(split).iterator(); } }); JavaRDD<String> javaRDD2 = data2.flatMap(new FlatMapFunction<String, String>() { @Override public Iterator<String> call(String s) throws Exception { String[] split = s.split(" "); return Arrays.asList(split).iterator(); } }); //创建一个示例,将文本映射成多个特征向量 final HashingTF tf = new HashingTF(10); //创建数据集分别存放数据 JavaRDD<LabeledPoint> map1 = javaRDD1.map(new Function<String, LabeledPoint>() {//类别1 @Override public LabeledPoint call(String s) throws Exception { return new LabeledPoint(0, tf.transform(Arrays.asList(s.split(" "))));//设置标签 } }); JavaRDD<LabeledPoint> map2 = javaRDD2.map(new Function<String, LabeledPoint>() {//类别2 @Override public LabeledPoint call(String s) throws Exception { return new LabeledPoint(1, tf.transform(Arrays.asList(s.split(" "))));//设置标签 } }); 
//合并结果并缓存 JavaRDD<LabeledPoint> trainData = map1.union(map2); trainData.cache();//使用迭代算法需要缓存 //使用SGD算法进行逻辑回归计算得到模型 LogisticRegressionModel model = new LogisticRegressionWithSGD().run(trainData.rdd()); //测试集 String t = "taltic ddd good gdds democracy fss care unassuming asdas mentality asdas"; String f = "good good asda sdasdas sada asd sad asd a"; //转化向量 Vector transformT = tf.transform(Arrays.asList(t.split(" "))); Vector transformF = tf.transform(Arrays.asList(f.split(" "))); //测试结果 System.out.println(t + " 结果为: " + model.predict(transformT)); System.out.println(f + " 结果为: " + model.predict(transformF)); trainData.unpersist(); sc.close(); } }