Spark 2.0 collaborative filtering with the ALS algorithm

ALS matrix factorization

A rating matrix A can often be approximated by the product of two much smaller matrices in a low-dimensional latent space. The idea is that a person's preferences, and the attributes of the items being rated, can be abstracted into that space without listing them explicitly: each user is mapped to a low-dimensional preference vector, each movie to a feature vector of the same dimension, and the user's affinity for a movie is then the inner product of the two vectors.
Treating similarity as a score, the rating matrix A (m × n) is approximated by the product of a user-feature matrix U (m × k) and an item-feature matrix V (n × k): A ≈ U Vᵀ.
Two families of optimization methods are commonly used for this factorization: alternating least squares (ALS) and stochastic gradient descent (SGD).
The loss function includes a regularization term (set with setRegParam).
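To make the factorization concrete, here is a minimal plain-Java sketch (independent of Spark, with made-up factor values) showing how a single predicted rating is recovered as the inner product of a user's latent vector and an item's latent vector:

```java
public class LatentFactorDemo {
    // Predicted rating for (user, item) = dot product of their k-dimensional latent vectors.
    static double predict(double[] userFactors, double[] itemFactors) {
        double score = 0.0;
        for (int f = 0; f < userFactors.length; f++) {
            score += userFactors[f] * itemFactors[f];
        }
        return score;
    }

    public static void main(String[] args) {
        // Hypothetical k = 3 latent factors, standing in for what ALS would learn.
        double[] user  = {0.8, 0.1, 0.5};
        double[] movie = {1.2, 0.4, 2.0};
        System.out.println("predicted rating = " + predict(user, movie));
    }
}
```

In the real model, one such dot product is computed for every (user, item) cell of the approximated matrix A ≈ U Vᵀ.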

Parameter selection

numBlocks: the number of blocks the users and items are partitioned into in order to parallelize computation (defaults to 10).
rank: the number of latent factors in the model (defaults to 10).
maxIter: the maximum number of iterations to run (defaults to 10).
regParam: the regularization parameter in ALS (defaults to 1.0).
implicitPrefs: whether to use the explicit-feedback ALS variant or the one adapted for implicit feedback data (defaults to false, meaning explicit feedback).
alpha: applicable only to the implicit-feedback variant; governs the baseline confidence in preference observations (defaults to 1.0).
nonnegative: whether to use nonnegativity constraints for least squares (defaults to false).

ALS als = new ALS()
        .setMaxIter(10)      // maximum number of iterations; too large a value can cause java.lang.StackOverflowError
        .setRegParam(0.16)   // regularization parameter
        .setAlpha(1.0)
        .setImplicitPrefs(false)
        .setNonnegative(false)
        .setNumBlocks(10)
        .setRank(10)
        .setUserCol("userId")
        .setItemCol("movieId")
        .setRatingCol("rating");

Note the following restriction:
the DataFrame-based API supports only integer user and item IDs, i.e. IDs must fall within the integer value range.

The DataFrame-based API for ALS currently only supports integers for
user and item ids. Other numeric types are supported for the user and
item id columns, but the ids must be within the integer value range.
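A quick defensive check for that restriction (plain Java, a hypothetical helper not part of Spark) verifies that an ID stored in a wider numeric type actually fits the integer range before it is fed to ALS:

```java
public class IdRangeCheck {
    // Returns true if the given id can safely be used as an ALS user/item id.
    static boolean fitsIntRange(long id) {
        return id >= Integer.MIN_VALUE && id <= Integer.MAX_VALUE;
    }

    public static void main(String[] args) {
        System.out.println(fitsIntRange(123456L));        // true
        System.out.println(fitsIntRange(3_000_000_000L)); // false: overflows int
    }
}
```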

// Loop over the regularization parameter; the RMSE for each value is computed by the evaluator
      List<Double> RMSE = new ArrayList<>();      // a List holding every RMSE
      for (int i = 0; i < 20; i++) {              // 20 iterations
          double lambda = (i * 5 + 1) * 0.01;     // regParam increases in steps of 0.05
          ALS als = new ALS()
          .setMaxIter(5)                          // maximum number of iterations
          .setRegParam(lambda)                    // regularization parameter
          .setUserCol("userId")
          .setItemCol("movieId")
          .setRatingCol("rating");
          ALSModel model = als.fit(training);
          // Evaluate the model by computing the RMSE on the test data
          Dataset<Row> predictions = model.transform(test);
          // RegressionEvaluator.setMetricName supports four metrics:
          //"rmse" (default): root mean squared error
          //"mse": mean squared error
          //"r2": R^2 metric
          //"mae": mean absolute error
          RegressionEvaluator evaluator = new RegressionEvaluator()
          .setMetricName("rmse") // RMS error
          .setLabelCol("rating")
          .setPredictionCol("prediction");
          Double rmse = evaluator.evaluate(predictions);
          RMSE.add(rmse);
          System.out.println("RegParam " + lambda + " RMSE " + rmse + "\n");
      }
      // Print all results
      for (int j = 0; j < RMSE.size(); j++) {
          double lambda = (j * 5 + 1) * 0.01;
          System.out.println("RegParam= " + lambda + "  RMSE= " + RMSE.get(j) + "\n");
      }
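For reference, the "rmse" metric the evaluator reports is simply the square root of the mean squared difference between prediction and label; a minimal plain-Java sketch (hypothetical arrays standing in for the prediction and rating columns):

```java
public class RmseDemo {
    // RMSE = sqrt( mean( (prediction - rating)^2 ) )
    static double rmse(double[] predictions, double[] ratings) {
        double sumSq = 0.0;
        for (int i = 0; i < predictions.length; i++) {
            double diff = predictions[i] - ratings[i];
            sumSq += diff * diff;
        }
        return Math.sqrt(sumSq / predictions.length);
    }

    public static void main(String[] args) {
        double[] predicted = {3.5, 4.0, 2.0};
        double[] actual    = {3.0, 4.5, 2.0};
        System.out.println("RMSE = " + rmse(predicted, actual));
    }
}
```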
By running such a loop, a good parameter value can be found empirically; sample results:
regParam = 0.01  RMSE = 1.956
regParam = 0.06  RMSE = 1.166
regParam = 0.11  RMSE = 0.977
regParam = 0.16  RMSE = 0.962  // minimum RMSE: the best-fitting parameter
regParam = 0.21  RMSE = 0.985
regParam = 0.26  RMSE = 1.021
regParam = 0.31  RMSE = 1.061
regParam = 0.36  RMSE = 1.102
regParam = 0.41  RMSE = 1.144
regParam = 0.51  RMSE = 1.228
regParam = 0.56  RMSE = 1.267
regParam = 0.61  RMSE = 1.300
// With regParam fixed at 0.16, next study the effect of the iteration count
Output below. In a stand-alone environment, if the iteration count is too large, a java.lang.StackOverflowError is thrown because the current thread's stack fills up.
numMaxIteration= 1   RMSE= 1.7325
numMaxIteration= 4   RMSE= 1.0695
numMaxIteration= 7   RMSE= 1.0563
numMaxIteration= 10  RMSE= 1.055
numMaxIteration= 13  RMSE= 1.053
numMaxIteration= 16  RMSE= 1.053
// Test the number of latent factors (rank)
Rank =1  RMSErr = 1.1584
Rank =3  RMSErr = 1.1067
Rank =5  RMSErr = 0.9366
Rank =7  RMSErr = 0.9745
Rank =9  RMSErr = 0.9440
Rank =11  RMSErr = 0.9458
Rank =13  RMSErr = 0.9466
Rank =15  RMSErr = 0.9443
Rank =17  RMSErr = 0.9543
// You can also define your own evaluation metric with Spark SQL (below, a mean-absolute-error computation)
// Register the DataFrame as a temporary SQL view
predictions.createOrReplaceTempView("tmp_predictions");
Dataset<Row> absDiff = spark.sql("SELECT abs(prediction - rating) AS diff FROM tmp_predictions");
absDiff.createOrReplaceTempView("tmp_absDiff");
spark.sql("SELECT mean(diff) AS absMeanDiff FROM tmp_absDiff").show();
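That SQL computes the mean absolute error; a plain-Java sketch of the same quantity (hypothetical arrays standing in for the prediction and rating columns) makes the metric explicit:

```java
public class MaeDemo {
    // MAE = mean( |prediction - rating| ), the same value as the SQL mean(abs(prediction - rating)).
    static double meanAbsError(double[] predictions, double[] ratings) {
        double sumAbs = 0.0;
        for (int i = 0; i < predictions.length; i++) {
            sumAbs += Math.abs(predictions[i] - ratings[i]);
        }
        return sumAbs / predictions.length;
    }

    public static void main(String[] args) {
        double[] predicted = {3.5, 4.0, 2.0};
        double[] actual    = {3.0, 4.5, 2.0};
        System.out.println("MAE = " + meanAbsError(predicted, actual)); // prints "MAE = 0.3333333333333333"
    }
}
```

Note that RegressionEvaluator can also produce this metric directly via .setMetricName("mae"); the SQL route is only needed for metrics the evaluator does not provide.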

The complete code

The Rating class (public class Rating implements Serializable {...}) can be found at http://spark.apache.org/docs/latest/ml-collaborative-filtering.html:
package my.spark.ml.practice.classification;

import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.recommendation.ALS;
import org.apache.spark.ml.recommendation.ALSModel;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class myCollabFilter2 {  

    public static void main(String[] args) {
        SparkSession spark=SparkSession
                .builder()
                .appName("CoFilter")
                .master("local[4]")
                .config("spark.sql.warehouse.dir", "file:///G:/Projects/Java/Spark/spark-warehouse")
                .getOrCreate();

        String path="G:/Projects/CgyWin64/home/pengjy3/softwate/spark-2.0.0-bin-hadoop2.6/"
                + "data/mllib/als/sample_movielens_ratings.txt";

        // Suppress logging
        Logger.getLogger("org.apache.spark").setLevel(Level.WARN);
        Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF);
        //----------------- 1.0 Prepare the DataFrame ------------------------------
        // .javaRDD() converts the Dataset to an RDD,
        // then map each String row of the RDD to a Rating
        JavaRDD<Rating> ratingRDD = spark.read().textFile(path).javaRDD()
                .map(new Function<String, Rating>() {

                    @Override
                    public Rating call(String str) throws Exception {
                        return Rating.parseRating(str);
                    }
                });
        //System.out.println(ratingRDD.take(10).get(0).getMovieId());

        // Create a DataFrame from the JavaRDD (each row is an instantiated Rating object) and the Rating class
        Dataset<Row> ratings = spark.createDataFrame(ratingRDD, Rating.class);
        //ratings.show(30);

        // Randomly split the data into a training set and a test set
        double[] weights = new double[]{0.8, 0.2};
        long seed = 1234;
        Dataset<Row>[] split = ratings.randomSplit(weights, seed);
        Dataset<Row> training = split[0];
        Dataset<Row> test = split[1];

        //------------------ 2.0 Train a recommendation model with ALS on the training set ------------------
        for (int rank = 1; rank < 20; rank++)
        {
            // Define the algorithm
            ALS als = new ALS()
                    .setMaxIter(5)      // maximum number of iterations; too large a value can cause java.lang.StackOverflowError
                    .setRegParam(0.16)
                    .setUserCol("userId")
                    .setRank(rank)
                    .setItemCol("movieId")
                    .setRatingCol("rating");
            // Train the model
            ALSModel model=als.fit(training);
            //------------------ 3.0 Model evaluation: compute the RMSE (root mean squared error) ------------------
            Dataset predictions=model.transform(test);
            //predictions.show();
            RegressionEvaluator evaluator=new RegressionEvaluator()
                    .setMetricName("rmse")
                    .setLabelCol("rating")
                    .setPredictionCol("prediction");
            Double rmse=evaluator.evaluate(predictions);
            System.out.println("Rank =" + rank+"  RMSErr = " + rmse);               
        }       
    }
}

 

Origin www.cnblogs.com/a-du/p/10947743.html