Convert RDD to DataFrame
Today, while using Spark to process the MovieLens dataset, I needed to add a column to the data, so the dataset could not be read directly into a DataFrame: it has to be preprocessed to add the column first. I therefore read the data in as an RDD and processed it accordingly. After processing, the RDD has to be converted into a DataFrame so that the ml API can be used.
There are two ways to convert an RDD to a DataFrame:
- Use the reflection mechanism. Reflection is used to infer the schema of an RDD containing objects of a specific type. This approach leads to more concise code and works well when you already know the schema.
First define a case class:
case class Person(name: String, age: Int)
Then convert the RDD to a DataFrame:
val people = sc.textFile("examples/src/main/resources/people.txt")
  .map(_.split(","))
  .map(p => Person(p(0), p(1).trim.toInt))
  .toDF()
people.registerTempTable("people")
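Note that toDF() is not available on an RDD by default; it comes from the implicit conversions provided by the SQLContext (or, in Spark 2.x, the SparkSession). Once the temp table is registered, it can be queried with SQL. A minimal sketch, assuming sc and sqlContext are already in scope; the teenagers name and the query itself are just illustrations:

```scala
// toDF() above requires the implicit conversions to be in scope
import sqlContext.implicits._

// The registered temp table can now be queried with SQL;
// the result of sql() is itself a DataFrame
val teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19")
teenagers.map(t => "Name: " + t(0)).collect().foreach(println)
```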
- Use the programming interface. Construct a schema and apply it to an existing RDD.
First create a schema:
val schema = StructType(
  schemaString.split(" ").map(fieldName => StructField(fieldName, StringType, true)))
Then apply the schema to the RDD:
val rowRDD = people.map(_.split(",")).map(p => Row(p(0), p(1).trim))
val peopleDataFrame = sqlContext.createDataFrame(rowRDD, schema)
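For completeness, the fragments above fit together roughly like this; a sketch assuming people is an RDD[String] of comma-separated name,age lines and sqlContext is in scope:

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StructType, StructField, StringType}

// The schema is encoded in a plain string, one field name per token
val schemaString = "name age"
val schema = StructType(
  schemaString.split(" ").map(fieldName => StructField(fieldName, StringType, true)))

// Convert each line into a Row that matches the schema
val rowRDD = people.map(_.split(",")).map(p => Row(p(0), p(1).trim))

// Apply the schema to the RDD of Rows
val peopleDataFrame = sqlContext.createDataFrame(rowRDD, schema)
peopleDataFrame.printSchema()
```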
The official documentation describes this as follows:
When case classes cannot be defined ahead of time (for example, the structure of records is encoded in a string, or a text dataset will be parsed and fields will be projected differently for different users), a DataFrame can be created programmatically with three steps.
- Create an RDD of Rows from the original RDD
- Create the schema represented by a StructType matching the structure of Rows in the RDD created in Step 1.
- Apply the schema to the RDD of Rows via createDataFrame method provided by SparkSession.
The dataset I use has four columns: "userId", "movieId", "rating", and "timestamp". I want to keep the first three columns and add a "Favorable" column indicating whether the rating is greater than 3.
package ml

import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types._

object movielens {
  def main(args: Array[String]): Unit = {
    // Silence Spark's INFO/WARN logging
    Logger.getLogger("org").setLevel(Level.ERROR)

    val spark = SparkSession
      .builder
      .appName("MovieLensExample")
      .config("spark.sql.warehouse.dir", "file:///")
      .master("local")
      .getOrCreate()

    // Step 1: create an RDD of Rows from the original RDD, keeping the
    // first three columns and deriving Favorable from the rating
    val ratings = spark.sparkContext.textFile("F:\\program\\MyPrograms\\data\\ratings.csv")
      .map(_.split(","))
      .map(fields => Row(fields(0), fields(1), fields(2), fields(2).toDouble > 3))

    // Step 2: create the schema matching the structure of the Rows
    val schema =
      StructType(
        StructField("userId", StringType, true) ::
        StructField("movieId", StringType, true) ::
        StructField("rating", StringType, true) ::
        StructField("Favorable", BooleanType, true) :: Nil)

    // Step 3: apply the schema to the RDD of Rows
    val ratingsDF = spark.createDataFrame(ratings, schema)
    ratingsDF.show()

    spark.stop()
  }
}
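One caveat: if ratings.csv starts with a header line (userId,movieId,rating,timestamp), fields(2).toDouble will throw a NumberFormatException on that line. A sketch of one way to skip it, assuming the first line of the file is the header:

```scala
val raw = spark.sparkContext.textFile("F:\\program\\MyPrograms\\data\\ratings.csv")
val header = raw.first()               // assume the first line is the header
val ratings = raw.filter(_ != header)  // drop it before parsing
  .map(_.split(","))
  .map(fields => Row(fields(0), fields(1), fields(2), fields(2).toDouble > 3))
```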