Spark读取MySQL(Oracle)数据保存为libsvm格式

libsvm数据格式:

libsvm使用的训练数据和检验数据文件格式如下:

 [label] [index1]:[value1] [index2]:[value2][label] [index1]:[value1] [index2]:[value2]

label 目标值,就是说class(属于哪一类),就是你要分类的种类,通常是一些整数。

index 是有顺序的索引,通常是连续的整数。就是指特征编号,必须按照升序排列

value 就是特征值,用来train的数据,通常是一堆实数组成。

即:

目标值   第一维特征编号:第一维特征值   第二维特征编号:第二维特征值 …
 
目标值   第一维特征编号:第一维特征值   第二维特征编号:第二维特征值 …
 
……
 
目标值   第一维特征编号:第一维特征值   第二维特征编号:第二维特征值 …

例如:5 1:0.6875 2:0.1875 3:0.015625 4:0.109375

表示训练用的特征有4维,第一维是0.6875,第二维是0.1875,第三维是0.015625,第四维是0.109375 目标值是5

注意:训练和测试数据的格式必须相同,都如上所示。测试数据中的目标值是为了计算误差用。

依赖:

	<properties>
        <scala.version>2.11.8</scala.version>
        <spark.version>2.2.0</spark.version>
    </properties>
    <repositories>
        <repository>
            <id>cloudera</id>
            <url>https://repository.cloudera.com/artifactory/cloudera-repos/</url>
        </repository>
    </repositories>
    <dependencies>
        <dependency>
            <groupId>org.scala-lang</groupId>
            <artifactId>scala-library</artifactId>
            <version>${scala.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-core_2.11</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-sql_2.11</artifactId>
            <version>${spark.version}</version>
        </dependency>

        <dependency>
            <groupId>org.apache.hadoop</groupId>
            <artifactId>hadoop-client</artifactId>
            <version>2.6.0</version>
        </dependency>

        <dependency>
            <groupId>org.apache.hadoop</groupId>
            <artifactId>hadoop-mapreduce-client-core</artifactId>
            <version>2.6.0</version>
        </dependency>

        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_2.11</artifactId>
            <version>2.2.0</version>
        </dependency>

        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib-local_2.11</artifactId>
            <version>2.2.0</version>
        </dependency>

        <dependency>
            <groupId>mysql</groupId>
            <artifactId>mysql-connector-java</artifactId>
            <version>5.1.47</version>
        </dependency>
    </dependencies>

代码示例:

package com.spark.milib

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{
    
    DataFrame, Row, SparkSession}

object SqlDataToLibsvm {
    
    
  def main(args: Array[String]): Unit = {
    
    

    val sparkSession: SparkSession = SparkSession.builder().appName("test").master("local[4]").getOrCreate()

    val dataFrame: DataFrame = sparkSession.read.format("jdbc")
      .option("url", "jdbc:mysql://localhost:3306/test11")
      .option("dbtable", "tablename2")
      .option("user", "root")
      .option("password", "123")
      .load()

    dataFrame.createTempView("dataFrame")
    val frame: DataFrame = sparkSession.sql("select T_FACTOR,MP_ID,TG_ID,SJD,P0,P1,P2,P3,I1,I2,I3,U1,U2,U3 from dataFrame")

    val rdd: RDD[Row] = frame.rdd

    val data = rdd.map{
    
     line =>

      //此时line数据格式为:[22.0,11.0,11.0,12.0,22.0,33.0,22.0,23.0,24.0,11.0,23.0,13.0,25.0,23.0]
      //切割数据外面的中括号
      val lineStr: String = line.toString().substring(1,line.toString().length-2)
      val values = lineStr.toString().split(",").map(_.toDouble)
      //init返回除了最后一个元素的所有元素,作为特征向量
      //Vectors.dense向量化,dense密集型
      val feature = Vectors.dense(values.init)
      val label = values.last
      LabeledPoint(label, feature)
    }

    MLUtils.saveAsLibSVMFile(data,"D:\\test")
  }
}

效果:

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/weixin_44455388/article/details/107342126