制作测试数据源:
c1 85 c2 77 c3 88 c1 22 c1 66 c3 95 c3 54 c2 91 c2 66 c1 54 c1 65 c2 41 c4 65
spark scala实现代码:
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

/**
 * Group Top-N demo: reads "userId rating" pairs from a text file and
 * prints the top 3 ratings per userId using a row_number() window.
 */
object GroupTopN1 {
  // Needed on Windows so Spark can locate winutils.exe.
  System.setProperty("hadoop.home.dir", "D:\\Java_Study\\hadoop-common-2.2.0-bin-master")

  /** One "userId rating" record parsed from the input file. */
  case class Rating(userId: String, rating: Long)

  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf().setAppName("ALS with ML Pipeline")
    val spark = SparkSession
      .builder()
      .config(sparkConf)
      .master("local")
      .config("spark.sql.warehouse.dir", "/")
      .getOrCreate()
    import spark.implicits._
    import spark.sql

    val lines = spark.read.textFile("C:\\Users\\Administrator\\Desktop\\group.txt")

    // Split each line once (the original split twice per record) and map
    // it into a typed Rating row.
    val classScores = lines.map { line =>
      val parts = line.split(" ")
      Rating(parts(0), parts(1).toLong)
    }
    classScores.createOrReplaceTempView("tb_test")

    // FIX: the original filtered with `having(rn<=3)` and no GROUP BY.
    // HAVING without grouping is non-standard and is rejected (or behaves
    // inconsistently) depending on the Spark/SQL version. The portable
    // pattern for filtering on a window-function result is a subquery
    // plus WHERE on the ranked alias.
    val df = sql(
      s"""|select userId, rating, rn
          |from (
          |  select
          |    userId,
          |    rating,
          |    row_number() over(partition by userId order by rating desc) rn
          |  from tb_test
          |) ranked
          |where rn <= 3
          |""".stripMargin)
    df.show()
    spark.stop()
  }
}
打印结果:
+------+------+---+ |userId|rating| rn| +------+------+---+ | c1| 85| 1| | c1| 66| 2| | c1| 65| 3| | c4| 65| 1| | c3| 95| 1| | c3| 88| 2| | c3| 54| 3| | c2| 91| 1| | c2| 77| 2| | c2| 66| 3| +------+------+---+
spark java代码实现:
import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.MapFunction; import org.apache.spark.sql.*; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import scala.Function1; import javax.management.RuntimeErrorException; import java.util.List; import java.util.ArrayList; public class Test { public static void main(String[] args) { System.out.println("Hello"); SparkConf sparkConf = new SparkConf().setAppName("ALS with ML Pipeline"); SparkSession spark = SparkSession .builder() .config(sparkConf) .master("local") .config("spark.sql.warehouse.dir", "/") .getOrCreate(); // Create an RDD JavaRDD<String> peopleRDD = spark.sparkContext() .textFile("C:\\Users\\Administrator\\Desktop\\group.txt", 1) .toJavaRDD(); // The schema is encoded in a string String schemaString = "userId rating"; // Generate the schema based on the string of schema List<StructField> fields = new ArrayList<>(); StructField field1 = DataTypes.createStructField("userId", DataTypes.StringType, true); StructField field2 = DataTypes.createStructField("rating", DataTypes.LongType, true); fields.add(field1); fields.add(field2); StructType schema = DataTypes.createStructType(fields); // Convert records of the RDD (people) to Rows JavaRDD<Row> rowRDD = peopleRDD.map((Function<String, Row>) record -> { String[] attributes = record.split(" "); if(attributes.length!=2) { throw new Exception(); } return RowFactory.create(attributes[0],Long.valueOf( attributes[1].trim())); }); // Apply the schema to the RDD Dataset<Row> peopleDataFrame = spark.createDataFrame(rowRDD, schema); peopleDataFrame.createOrReplaceTempView("tb_test"); Dataset<Row> items = spark.sql("select userId,rating,row_number()over(partition by userId order by rating desc) rn " + "from tb_test " + "having(rn<=3)"); items.show(); spark.stop(); } }
输出结果同上边输出结果。