Spark: Computing the Top N Within Each Group

Create the test data source:

c1 85
c2 77
c3 88
c1 22
c1 66
c3 95
c3 54
c2 91
c2 66
c1 54
c1 65
c2 41
c4 65

Spark Scala implementation:

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

object GroupTopN1 {
  // Windows-only: point Hadoop at a local winutils installation so Spark can start
  System.setProperty("hadoop.home.dir", "D:\\Java_Study\\hadoop-common-2.2.0-bin-master")

  case class Rating(userId: String, rating: Long)

  def main(args: Array[String]) {
    val sparkConf = new SparkConf().setAppName("GroupTopN")
    val spark = SparkSession
      .builder()
      .config(sparkConf)
      .master("local")
      .config("spark.sql.warehouse.dir", "/")
      .getOrCreate()

    import spark.implicits._
    import spark.sql

    // Read the test file as a Dataset[String] and parse each "userId rating" line
    val lines = spark.read.textFile("C:\\Users\\Administrator\\Desktop\\group.txt")
    val classScores = lines.map { line =>
      val fields = line.split(" ")
      Rating(fields(0), fields(1).toLong)
    }

    classScores.createOrReplaceTempView("tb_test")

    // Rank ratings within each userId partition (highest first) and keep rn <= 3.
    // Filtering on the window alias through HAVING depends on how the Spark version
    // at hand resolves HAVING without GROUP BY; a subquery with WHERE is the more
    // portable form (see the sketch at the end of this post).
    val df = sql(
      s"""|select
          | userId,
          | rating,
          | row_number() over(partition by userId order by rating desc) rn
          |from tb_test
          |having(rn <= 3)
          |""".stripMargin)
    df.show()

    spark.stop()
  }
}

Printed output:

+------+------+---+
|userId|rating| rn|
+------+------+---+
|    c1|    85|  1|
|    c1|    66|  2|
|    c1|    65|  3|
|    c4|    65|  1|
|    c3|    95|  1|
|    c3|    88|  2|
|    c3|    54|  3|
|    c2|    91|  1|
|    c2|    77|  2|
|    c2|    66|  3|
+------+------+---+
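
The same Top-N can also be expressed without SQL through the DataFrame window API. A minimal sketch, assuming the classScores Dataset and the spark.implicits._ import from the Scala code above:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.row_number

// Partition by userId and order ratings descending within each partition
val w = Window.partitionBy("userId").orderBy($"rating".desc)

// Number the rows per partition and keep at most the top 3 ratings per user
val top3 = classScores
  .withColumn("rn", row_number().over(w))
  .where($"rn" <= 3)

top3.show()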

Spark Java implementation:

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.*;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.util.ArrayList;
import java.util.List;

public class Test {
    public static void main(String[] args) {
        SparkConf sparkConf = new SparkConf().setAppName("GroupTopN");
        SparkSession spark = SparkSession
                .builder()
                .config(sparkConf)
                .master("local")
                .config("spark.sql.warehouse.dir", "/")
                .getOrCreate();


        // Create an RDD
        JavaRDD<String> peopleRDD = spark.sparkContext()
                .textFile("C:\\Users\\Administrator\\Desktop\\group.txt", 1)
                .toJavaRDD();

        // Build the schema: two columns, userId (string) and rating (long)
        List<StructField> fields = new ArrayList<>();
        StructField field1 = DataTypes.createStructField("userId", DataTypes.StringType, true);
        StructField field2 = DataTypes.createStructField("rating", DataTypes.LongType, true);
        fields.add(field1);
        fields.add(field2);
        StructType schema = DataTypes.createStructType(fields);

        // Convert each "userId rating" line of the RDD to a Row
        JavaRDD<Row> rowRDD = peopleRDD.map((Function<String, Row>) record -> {
            String[] attributes = record.split(" ");
            if (attributes.length != 2) {
                throw new IllegalArgumentException("Malformed line: " + record);
            }
            return RowFactory.create(attributes[0], Long.valueOf(attributes[1].trim()));
        });

        // Apply the schema to the RDD
        Dataset<Row> peopleDataFrame = spark.createDataFrame(rowRDD, schema);

        peopleDataFrame.createOrReplaceTempView("tb_test");

        Dataset<Row> items = spark.sql("select userId,rating,row_number()over(partition by userId order by rating desc) rn " +
                "from tb_test " +
                "having(rn<=3)");
        items.show();

        spark.stop();
    }
}

The output is the same as the output shown above.
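
The HAVING-based filter in both versions depends on how the Spark release at hand resolves HAVING without GROUP BY. A more portable sketch of the same query wraps the window function in a subquery and filters with WHERE; it is shown below with the Scala spark.sql call, and the Java query string is analogous:

val top3 = spark.sql(
  """|select userId, rating, rn
     |from (
     |  select
     |    userId,
     |    rating,
     |    row_number() over(partition by userId order by rating desc) rn
     |  from tb_test
     |) t
     |where rn <= 3
     |""".stripMargin)
top3.show()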

Reposted from www.cnblogs.com/yy3b2007com/p/9363474.html