A Deep Dive into Spark ALS -- Source Code Walkthrough and Interface Optimization

The relevant source files are ALS.scala and MatrixFactorizationModel.scala.

Usage

import org.apache.spark.mllib.recommendation.{ALS, Rating}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.storage.StorageLevel

val spark = SparkSession.builder().enableHiveSupport().appName("RecommendProduct").getOrCreate()
// Encoding tables that map the raw user pin and item id to the integer indexes mllib expects
val user_no = spark.sql("select pin, uid, cate from user_encoding_table").persist(StorageLevel.MEMORY_ONLY_SER)
val item_no = spark.sql("select itemid, item_id, fund_type from item_id_encoding_table").cache()
// Join the raw ratings with both encoding tables and build the Rating RDD
val ratings = spark.sql("select pin, item_id, rate from ratings_table" + partTableName + "_rate_dt where dt='" + date + "'")
      .join(user_no, "pin")
      .join(item_no, "item_id")
      .select("uid", "item_id", "rate")
      .rdd
      .map { case Row(uid: Int, item_id: Int, rate: Double) => Rating(uid, item_id, rate) }

val model = ALS.trainImplicit(ratings, rank = 10, iterations = 5, lambda = 0.01, alpha = 0.01)
println("training finished.")

Note: the input here must be an RDD of Rating:

@Since("0.8.0")
case class Rating @Since("0.8.0") (
    @Since("0.8.0") user: Int,
    @Since("0.8.0") product: Int,
    @Since("0.8.0") rating: Double)

Source Code Walkthrough

// The trainImplicit entry point on the ALS object
@Since("0.8.0")
object ALS {
    @Since("0.8.1")
      def trainImplicit(
          ratings: RDD[Rating],
          rank: Int,
          iterations: Int,
          lambda: Double,
          blocks: Int,
          alpha: Double,
          seed: Long
        ): MatrixFactorizationModel = {
        // Instantiate an ALS object and call its run method to do the training
        new ALS(blocks, blocks, rank, iterations, lambda, true, alpha, seed).run(ratings)
      }
}
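
trainImplicit itself is only a thin wrapper: it builds an ALS instance with implicitPrefs = true and delegates to run. The same training can be written with the instance setters, which also exposes options the static helper hides. A sketch, reusing the ratings RDD from the usage section; the seed value is an arbitrary choice:

val model = new ALS()
  .setRank(10)
  .setIterations(5)
  .setLambda(0.01)
  .setImplicitPrefs(true)   // what trainImplicit hard-codes to true
  .setAlpha(0.01)
  .setBlocks(-1)            // -1 lets run() pick the number of blocks automatically
  .setSeed(42L)             // fixed seed for reproducibility
  .run(ratings)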

Let's look at how the ALS class is defined and how it trains:

// The actual ALS training algorithm is implemented in the ml package
import org.apache.spark.ml.recommendation.{ALS => NewALS}

class ALS private (
    private var numUserBlocks: Int,
    private var numProductBlocks: Int,
    private var rank: Int,
    private var iterations: Int,
    private var lambda: Double,
    private var implicitPrefs: Boolean,
    private var alpha: Double,
    private var seed: Long = System.nanoTime()
  ) extends Serializable with Logging {

  /**
   * Constructs an ALS instance with default parameters: {numBlocks: -1, rank: 10, iterations: 10,
   * lambda: 0.01, implicitPrefs: false, alpha: 1.0}.
   */
  @Since("0.8.0")
  def this() = this(-1, -1, 10, 10, 0.01, false, 1.0)
  // ... (the many setter methods that initialize the member fields are omitted here)

    // The training method; returns a matrix factorization model
      def run(ratings: RDD[Rating]): MatrixFactorizationModel = {
        val sc = ratings.context
        // Number of user blocks: the larger of Spark's default parallelism and half the number of partitions of the ratings RDD
        val numUserBlocks = if (this.numUserBlocks == -1) {
          math.max(sc.defaultParallelism, ratings.partitions.length / 2)
        } else {
          this.numUserBlocks
        }
        // Number of product blocks: computed with the same rule as the user blocks
        val numProductBlocks = if (this.numProductBlocks == -1) {
          math.max(sc.defaultParallelism, ratings.partitions.length / 2)
        } else {
          this.numProductBlocks
        }
        // Run the training; returns the user factor vectors and the product factor vectors
        val (floatUserFactors, floatProdFactors) = NewALS.train[Int](
          // Convert the input into the ml package's Rating format (ratings are cast to Float)
          ratings = ratings.map(r => NewALS.Rating(r.user, r.product, r.rating.toFloat)),
          rank = rank,
          numUserBlocks = numUserBlocks,
          numItemBlocks = numProductBlocks,
          maxIter = iterations,
          regParam = lambda,
          implicitPrefs = implicitPrefs,
          alpha = alpha,
          nonnegative = nonnegative,
          intermediateRDDStorageLevel = intermediateRDDStorageLevel,
          finalRDDStorageLevel = StorageLevel.NONE,
          checkpointInterval = checkpointInterval,
          seed = seed)
        // Convert the factors back to Double and persist the user and product factor RDDs
        val userFactors = floatUserFactors
          .mapValues(_.map(_.toDouble))
          .setName("users")
          .persist(finalRDDStorageLevel)
        val prodFactors = floatProdFactors
          .mapValues(_.map(_.toDouble))
          .setName("products")
          .persist(finalRDDStorageLevel)
        if (finalRDDStorageLevel != StorageLevel.NONE) {
          userFactors.count()
          prodFactors.count()
        }
        // Return a MatrixFactorizationModel instance
        new MatrixFactorizationModel(rank, userFactors, prodFactors)
      }
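
So run does little more than delegate to NewALS.train and wrap the two factor RDDs in a MatrixFactorizationModel. Predicting a score is then essentially a dot product between the matching user and product factor vectors; a simplified sketch of that idea (the real predict looks the vectors up by id and uses BLAS):

// Hypothetical helper: the predicted score is the dot product of two rank-length vectors
def predictScore(userFactor: Array[Double], productFactor: Array[Double]): Double = {
  require(userFactor.length == productFactor.length, "both vectors must have length rank")
  userFactor.zip(productFactor).map { case (u, p) => u * p }.sum
}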

Next, let's see how ALS actually performs the training. The implementation is in ALS.scala in the ml package:

 /**
   * :: DeveloperApi ::
   * Implementation of the ALS algorithm.
   */
  @DeveloperApi
  def train[ID: ClassTag]( // scalastyle:ignore
      ratings: RDD[Rating[ID]],
      rank: Int = 10,
      numUserBlocks: Int = 10,
      numItemBlocks: Int = 10,
      maxIter: Int = 10,
      regParam: Double = 1.0,
      implicitPrefs: Boolean = false,
      alpha: Double = 1.0,
      nonnegative: Boolean = false,
      intermediateRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
      finalRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
      checkpointInterval: Int = 10,
      seed: Long = 0L)(
      implicit ord: Ordering[ID]): (RDD[(ID, Array[Float])], RDD[(ID, Array[Float])]) = {
    require(intermediateRDDStorageLevel != StorageLevel.NONE,
      "ALS is not designed to run without persisting intermediate RDDs.")
    val sc = ratings.sparkContext
    // Partitioner for the user blocks
    val userPart = new ALSPartitioner(numUserBlocks)
    // Partitioner for the item blocks
    val itemPart = new ALSPartitioner(numItemBlocks)
    // Local index encoder for the user block partitions
    val userLocalIndexEncoder = new LocalIndexEncoder(userPart.numPartitions)
    // Local index encoder for the item block partitions
    val itemLocalIndexEncoder = new LocalIndexEncoder(itemPart.numPartitions)
    // Choose the least-squares solver: NNLS when nonnegative factors are required, otherwise Cholesky
    val solver = if (nonnegative) new NNLSSolver else new CholeskySolver
    val blockRatings = partitionRatings(ratings, userPart, itemPart)
      .persist(intermediateRDDStorageLevel)
    val (userInBlocks, userOutBlocks) =
      makeBlocks("user", blockRatings, userPart, itemPart, intermediateRDDStorageLevel)
    // materialize blockRatings and user blocks
    userOutBlocks.count()
    val swappedBlockRatings = blockRatings.map {
      case ((userBlockId, itemBlockId), RatingBlock(userIds, itemIds, localRatings)) =>
        ((itemBlockId, userBlockId), RatingBlock(itemIds, userIds, localRatings))
    }
    val (itemInBlocks, itemOutBlocks) =
      makeBlocks("item", swappedBlockRatings, itemPart, userPart, intermediateRDDStorageLevel)
    // materialize item blocks
    itemOutBlocks.count()
    val seedGen = new XORShiftRandom(seed)
    var userFactors = initialize(userInBlocks, rank, seedGen.nextLong())
    var itemFactors = initialize(itemInBlocks, rank, seedGen.nextLong())
    var previousCheckpointFile: Option[String] = None
    val shouldCheckpoint: Int => Boolean = (iter) =>
      sc.checkpointDir.isDefined && checkpointInterval != -1 && (iter % checkpointInterval == 0)
    val deletePreviousCheckpointFile: () => Unit = () =>
      previousCheckpointFile.foreach { file =>
        try {
          val checkpointFile = new Path(file)
          checkpointFile.getFileSystem(sc.hadoopConfiguration).delete(checkpointFile, true)
        } catch {
          case e: IOException =>
            logWarning(s"Cannot delete checkpoint file $file:", e)
        }
      }
    if (implicitPrefs) {
      for (iter <- 1 to maxIter) {
        userFactors.setName(s"userFactors-$iter").persist(intermediateRDDStorageLevel)
        val previousItemFactors = itemFactors
        itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
          userLocalIndexEncoder, implicitPrefs, alpha, solver)
        previousItemFactors.unpersist()
        itemFactors.setName(s"itemFactors-$iter").persist(intermediateRDDStorageLevel)
        // TODO: Generalize PeriodicGraphCheckpointer and use it here.
        val deps = itemFactors.dependencies
        if (shouldCheckpoint(iter)) {
          itemFactors.checkpoint() // itemFactors gets materialized in computeFactors
        }
        val previousUserFactors = userFactors
        userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
          itemLocalIndexEncoder, implicitPrefs, alpha, solver)
        if (shouldCheckpoint(iter)) {
          ALS.cleanShuffleDependencies(sc, deps)
          deletePreviousCheckpointFile()
          previousCheckpointFile = itemFactors.getCheckpointFile
        }
        previousUserFactors.unpersist()
      }
    } else {
      for (iter <- 0 until maxIter) {
        itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
          userLocalIndexEncoder, solver = solver)
        if (shouldCheckpoint(iter)) {
          val deps = itemFactors.dependencies
          itemFactors.checkpoint()
          itemFactors.count() // checkpoint item factors and cut lineage
          ALS.cleanShuffleDependencies(sc, deps)
          deletePreviousCheckpointFile()
          previousCheckpointFile = itemFactors.getCheckpointFile
        }
        userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
          itemLocalIndexEncoder, solver = solver)
      }
    }
    val userIdAndFactors = userInBlocks
      .mapValues(_.srcIds)
      .join(userFactors)
      .mapPartitions({ items =>
        items.flatMap { case (_, (ids, factors)) =>
          ids.view.zip(factors)
        }
      // Preserve the partitioning because IDs are consistent with the partitioners in userInBlocks
      // and userFactors.
      }, preservesPartitioning = true)
      .setName("userFactors")
      .persist(finalRDDStorageLevel)
    val itemIdAndFactors = itemInBlocks
      .mapValues(_.srcIds)
      .join(itemFactors)
      .mapPartitions({ items =>
        items.flatMap { case (_, (ids, factors)) =>
          ids.view.zip(factors)
        }
      }, preservesPartitioning = true)
      .setName("itemFactors")
      .persist(finalRDDStorageLevel)
    if (finalRDDStorageLevel != StorageLevel.NONE) {
      userIdAndFactors.count()
      itemFactors.unpersist()
      itemIdAndFactors.count()
      userInBlocks.unpersist()
      userOutBlocks.unpersist()
      itemInBlocks.unpersist()
      itemOutBlocks.unpersist()
      blockRatings.unpersist()
    }
    (userIdAndFactors, itemIdAndFactors)
  }
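
Among the helpers used above, LocalIndexEncoder is worth a closer look: it packs a (blockId, localIndex) pair into a single Int so that the in/out block metadata stays compact, with the block id in the high bits and the local index in the low bits. A simplified sketch of the idea (not the exact Spark source):

class LocalIndexEncoderSketch(numBlocks: Int) extends Serializable {
  require(numBlocks > 0, "numBlocks must be positive")
  // Low bits hold the local index; the remaining high bits hold the block id
  private val numLocalIndexBits = math.min(Integer.numberOfLeadingZeros(numBlocks - 1), 31)
  private val localIndexMask = (1 << numLocalIndexBits) - 1

  def encode(blockId: Int, localIndex: Int): Int = {
    require(blockId < numBlocks && (localIndex & ~localIndexMask) == 0)
    (blockId << numLocalIndexBits) | localIndex
  }
  def blockId(encoded: Int): Int = encoded >>> numLocalIndexBits   // unsigned shift recovers the block id
  def localIndex(encoded: Int): Int = encoded & localIndexMask
}

For example, with 10 blocks encode(3, 12345) stores both values in one Int, and blockId/localIndex recover them with a shift and a mask.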

Reposted from blog.csdn.net/fegnkuang/article/details/81487939