PredictLRwithLBFGS.scala:
/**训练-逻辑回归 * Created by Administrator on 2018/8/24. */ object TrainLRwithLBFGS { val sparkConf = new SparkConf().setAppName("Beijing traffic").setMaster("local") val sc = new SparkContext(sparkConf) // create the date/time formatters val dayFormat = new SimpleDateFormat("yyyyMMdd") val minuteFormat = new SimpleDateFormat("HHmm") def main(args: Array[String]) { // fetch data from redis val jedis = RedisClient.pool.getResource jedis.select(1) // find relative road monitors for specified road // val camera_ids = List("310999003001","310999003102","310999000106","310999000205","310999007204") val camera_ids = List("310999003001","310999003102") val camera_relations:Map[String,Array[String]] = Map[String,Array[String]]( "310999003001" -> Array("310999003001","310999003102","310999000106","310999000205","310999007204"), "310999003102" -> Array("310999003001","310999003102","310999000106","310999000205","310999007204") ) val temp = camera_ids.map({ camera_id => val hours = 3 val nowtimelong = System.currentTimeMillis(); val now = new Date(nowtimelong) val day = dayFormat.format(now) // Option Some None val list = camera_relations.get(camera_id).get val relations = list.map({ camera_id => println(camera_id) // fetch records of one camera for three hours ago (camera_id, jedis.hgetAll(day + "_" + camera_id)) }) relations.foreach(println) // organize above records per minute to train data set format (MLUtils.loadLibSVMFile) val trainSet = ArrayBuffer[LabeledPoint]() // start begin at index 3 //往训练集里面加上每一行数据,往前三小时,每次分钟数减1 for(i <- Range(60*hours-3,0,-1)){ val featuresX = ArrayBuffer[Double]() val featureY = ArrayBuffer[Double]() // get current minute and recent two minutes //填三分钟的数据 for(index <- 0 to 2){ val tempOne = nowtimelong - 60 * 1000 * (i-index) val d = new Date(tempOne) val tempMinute = minuteFormat.format(d) val tempNext = tempOne - 60 * 1000 * (-1) val dNext = new Date(tempNext) val tempMinuteNext = minuteFormat.format(dNext) //把自己和相邻路段填入值 for((k,v) <- relations){ // k->camera_id ; v->HashMap val map = v //填入y坐标值 if(index == 2 && k == camera_id){ if (map.containsKey(tempMinuteNext)) { val info = map.get(tempMinuteNext).split("_") val f = info(0).toFloat / info(1).toFloat featureY += f } } //填入X坐标值 if (map.containsKey(tempMinute)){ val info = map.get(tempMinute).split("_") val f = info(0).toFloat / info(1).toFloat featuresX += f } else{ featuresX += -1.0 } } } //判断y坐标有没有值 if(featureY.toArray.length == 1 ){ val label = (featureY.toArray).head val record = LabeledPoint(if ((label.toInt/10)<10) (label.toInt/10) else 10.0, Vectors.dense(featuresX.toArray)) // println(record) trainSet += record } } trainSet.foreach(println) println(trainSet.length) val data = sc.parallelize(trainSet) println(data) // Split data into training (60%) and test (40%). val splits = data.randomSplit(Array(0.6, 0.4), seed = 1000L) val training = splits(0) val test = splits(1) if(!data.isEmpty()){ // Run training algorithm to build the model //按逻辑回归模型训练--算法调用 val model = new LogisticRegressionWithLBFGS() .setNumClasses(11) .run(training) // Compute raw scores on the test set. //测试集测试 val predictionAndLabels = test.map { case LabeledPoint(label, features) => val prediction = model.predict(features) (prediction, label) } predictionAndLabels.foreach(x=> println(x)) // Get evaluation metrics. //结果评估 val metrics = new MulticlassMetrics(predictionAndLabels) val precision = metrics.precision println("Precision = " + precision) //如果准确度大于80%,保存起来 if(precision > 0.8){ val path = "hdfs://node1:9000/model_"+camera_id+"_"+nowtimelong model.save(sc, path) println("saved model to "+ path) jedis.hset("model", camera_id , path) } } }) RedisClient.pool.returnResource(jedis) } } |
PredictLRwithLBFGS.scala:
/**加载存储的模型,进行预测 * Created by Administrator on 2018/8/24. */ object PredictLRwithLBFGS { val sparkConf = new SparkConf().setAppName("Shanghai traffic").setMaster("local[4]") val sc = new SparkContext(sparkConf) // create the date/time formatters val dayFormat = new SimpleDateFormat("yyyyMMdd") val minuteFormat = new SimpleDateFormat("HHmm") val sdf = new SimpleDateFormat( "yyyy-MM-dd_HH:mm:ss" ); def main(args: Array[String]) { val input = "2017-04-20_09:50:00" val date = sdf.parse( input ); val inputTimeLong = date.getTime() val inputTime = new Date(inputTimeLong) val day = dayFormat.format(inputTime) // fetch data from redis val jedis = RedisClient.pool.getResource jedis.select(1) // find relative road monitors for specified road // val camera_ids = List("310999003001","310999003102","310999000106","310999000205","310999007204") val camera_ids = List("310999003001","310999003102") val camera_relations:Map[String,Array[String]] = Map[String,Array[String]]( "310999003001" -> Array("310999003001","310999003102","310999000106","310999000205","310999007204"), "310999003102" -> Array("310999003001","310999003102","310999000106","310999000205","310999007204") ) val temp = camera_ids.map({ camera_id => val list = camera_relations.get(camera_id).get val relations = list.map({ camera_id => println(camera_id) // fetch records of one camera for three hours ago (camera_id, jedis.hgetAll(day + "_" + camera_id)) }) relations.foreach(println) // organize above records per minute to train data set format (MLUtils.loadLibSVMFile) val aaa = ArrayBuffer[Double]() // get current minute and recent two minutes for(index <- 0 to 2){ val tempOne = inputTimeLong - 60 * 1000 * index val tempMinute = minuteFormat.format(inputTime) for((k,v) <- relations){ // k->camera_id ; v->speed val map = v if (map.containsKey(tempMinute)){ val info = map.get(tempMinute).split("_") val f = info(0).toFloat / info(1).toFloat aaa += f } else{ aaa += -1.0 } } } // Run training algorithm to build the model //读取训练好的模型,进行预测 val path = jedis.hget("model",camera_id) val model = LogisticRegressionModel.load(sc, path) // Compute raw scores on the test set. val prediction = model.predict(Vectors.dense(aaa.toArray)) println(input+"\t"+camera_id+"\t"+prediction+"\t") jedis.hset(input, camera_id, prediction.toString) }) RedisClient.pool.returnResource(jedis) } } |