Spark MLlib中ALS交替最小二乘法推荐算法的使用

本文首发于我的个人博客QIMING.INFO,转载请带上链接及署名。

ALS(Alternating Least Square),交替最小二乘法。在机器学习中,特指使用最小二乘法的一种协同推荐算法。本文通过代码来演示用spark运行ALS算法的一个小例子。

算法简介

ALS算法通过观察到的所有用户给商品的打分,来推断每个用户的喜好并向用户推荐适合的商品。

其原理简单说就是假设用户评分矩阵是用户特征矩阵乘以物品特征矩阵得到的,即:A(m*n)=U(m*k)*V(k*n),然后得到一个评分矩阵。具体原理请自行查阅,本文主要为使用。

通常,调用ALS算法进行训练时有4个重要参数,分别是ratingsrankiterations,和lambda

  • ratings指用户提供的训练数据,它包括用户id集、商品id集以及相应的打分集;
  • rank表示隐含因素的数量,即特征的数量,也就是分解矩阵的k值。
  • iterations表示最大迭代次数;
  • lambda表示正则因子,可省略,默认为0.01。

运行步骤

数据说明

数据格式为:用户id,物品id,评分

[xuqm@cu01 ML_Data]$ cat input/test.data 
1,1,5.0
1,2,1.0
1,3,5.0
1,4,1.0
2,1,5.0
2,2,1.0
2,3,5.0
2,4,1.0
3,1,1.0
3,2,5.0
3,3,1.0
3,4,5.0
4,1,1.0
4,2,5.0
4,3,1.0
4,4,5.0

代码及说明


package nwpuhpc.antirisk.ml

import org.apache.log4j.{Level, Logger}
import org.apache.spark.mllib.recommendation.{ALS, Rating}
import org.apache.spark.{SparkConf, SparkContext}

object ALSTest {

  // 构建Spark对象
  val conf = new SparkConf().setAppName("ALSTest")
  val sc = new SparkContext(conf)
  Logger.getRootLogger.setLevel(Level.WARN)

  // 读取样本数据
  val data = sc.textFile("/home/xuqm/ML_Data/input/test.data")
  val ratings = data.map(_.split(',') match {
    case Array(user, item, rate) =>
      Rating(user.toInt, item.toInt, rate.toDouble)
  })

  // 拆分成训练集和测试集
  val dataParts = ratings.randomSplit(Array(0.8, 0.2))
  val trainingRDD = dataParts(0).cache()
  val testRDD = dataParts(1)

  // 建立ALS交替最小二乘算法模型并训练
  val rank = 10
  val numIterations = 20
  val model = ALS.train(trainingRDD, rank, numIterations, 0.01)

  // 取出测试集中的用户id和商品id
  val usersProducts = testRDD.map {
    case Rating(user, product, rate) =>
      (user, product)
  }

  // 用训练好的模型预测测试集的结果
  val predictions = model.predict(usersProducts).map {
    case Rating(user, product, rate) =>
      ((user, product), rate)
  }

  val ratesAndPreds = testRDD.map {
    case Rating(user, product, rate) =>
      ((user, product), rate)
  }.join(predictions)

  // 输出误差
  val MSE = ratesAndPreds.map {
    case ((user, product), (r1, r2)) =>
      val err = (r1 - r2)
      err * err
  }.mean()
  println("Mean Squared Error = " + MSE)

  // 打印输出预测值
  println("User" + "\t" + "Products" + "\t" + "Rate" + "\t" + "Prediction")
  ratesAndPreds.collect.foreach(
    rating => {
      println(rating._1._1 + "\t" + rating._1._2 + "\t" + rating._2._1 + "\t" + rating._2._2)
    }
  )

}

结果展示

可以看出,误差不是很大。

猜你喜欢

转载自blog.csdn.net/u011630228/article/details/81364959