12.6 地图实时路况预测

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/u011418530/article/details/82177764

PredictLRwithLBFGS.scala:

/**训练-逻辑回归

 * Created by Administrator on 2018/8/24.

 */

object TrainLRwithLBFGS {

    val sparkConf = new SparkConf().setAppName("Beijing traffic").setMaster("local")

    val sc = new SparkContext(sparkConf)

    // create the date/time formatters

    val dayFormat = new SimpleDateFormat("yyyyMMdd")

    val minuteFormat = new SimpleDateFormat("HHmm")

    def main(args: Array[String]) {

        // fetch data from redis

        val jedis = RedisClient.pool.getResource

        jedis.select(1)

        // find relative road monitors for specified road

        // val camera_ids = List("310999003001","310999003102","310999000106","310999000205","310999007204")

        val camera_ids = List("310999003001","310999003102")

        val camera_relations:Map[String,Array[String]] = Map[String,Array[String]](

            "310999003001" -> Array("310999003001","310999003102","310999000106","310999000205","310999007204"),

            "310999003102" -> Array("310999003001","310999003102","310999000106","310999000205","310999007204")

        )

        val temp = camera_ids.map({ camera_id =>

            val hours = 3

            val nowtimelong = System.currentTimeMillis();

            val now = new Date(nowtimelong)

            val day = dayFormat.format(now)

            // Option Some None

            val list = camera_relations.get(camera_id).get

            val relations = list.map({ camera_id =>

                println(camera_id)

                // fetch records of one camera for three hours ago

                (camera_id, jedis.hgetAll(day + "_" + camera_id))

            })

            relations.foreach(println)

            // organize above records per minute to train data set format (MLUtils.loadLibSVMFile)

            val trainSet = ArrayBuffer[LabeledPoint]()

            // start begin at index 3

            //往训练集里面加上每一行数据,往前三小时,每次分钟数减1

            for(i <- Range(60*hours-3,0,-1)){

                val featuresX = ArrayBuffer[Double]()

                val featureY = ArrayBuffer[Double]()

                // get current minute and recent two minutes

                //填三分钟的数据

                for(index <- 0 to 2){

                    val tempOne = nowtimelong - 60 * 1000 * (i-index)

                    val d = new Date(tempOne)

                    val tempMinute = minuteFormat.format(d)

                    val tempNext = tempOne - 60 * 1000 * (-1)

                    val dNext = new Date(tempNext)

                    val tempMinuteNext = minuteFormat.format(dNext)

                    //把自己和相邻路段填入值

                    for((k,v) <- relations){

                        // k->camera_id ; v->HashMap

                        val map = v

                        //填入y坐标值

                        if(index == 2 && k == camera_id){

                            if (map.containsKey(tempMinuteNext)) {

                                val info = map.get(tempMinuteNext).split("_")

                                val f = info(0).toFloat / info(1).toFloat

                                featureY += f

                            }

                        }

                        //填入X坐标值

                        if (map.containsKey(tempMinute)){

                            val info = map.get(tempMinute).split("_")

                            val f = info(0).toFloat / info(1).toFloat

                            featuresX += f

                        } else{

                            featuresX += -1.0

                        }

                    }

                }

                //判断y坐标有没有值

                if(featureY.toArray.length == 1 ){

                    val label = (featureY.toArray).head

                    val record = LabeledPoint(if ((label.toInt/10)<10) (label.toInt/10) else 10.0, Vectors.dense(featuresX.toArray))

// println(record)

                    trainSet += record

                }

            }

            trainSet.foreach(println)

            println(trainSet.length)

            val data = sc.parallelize(trainSet)

            println(data)

            // Split data into training (60%) and test (40%).

            val splits = data.randomSplit(Array(0.6, 0.4), seed = 1000L)

            val training = splits(0)

            val test = splits(1)

            if(!data.isEmpty()){

                // Run training algorithm to build the model

              //按逻辑回归模型训练--算法调用

                val model = new LogisticRegressionWithLBFGS()

                        .setNumClasses(11)

                        .run(training)

                // Compute raw scores on the test set.

                //测试集测试

                val predictionAndLabels = test.map { case LabeledPoint(label, features) =>

                    val prediction = model.predict(features)

                    (prediction, label)

                }

                predictionAndLabels.foreach(x=> println(x))

                // Get evaluation metrics.

                //结果评估

                val metrics = new MulticlassMetrics(predictionAndLabels)

                val precision = metrics.precision

                println("Precision = " + precision)

                //如果准确度大于80%,保存起来

                if(precision > 0.8){

                    val path = "hdfs://node1:9000/model_"+camera_id+"_"+nowtimelong

                    model.save(sc, path)

                    println("saved model to "+ path)

                    jedis.hset("model", camera_id , path)

                }

            }

        })

        RedisClient.pool.returnResource(jedis)

    }

}

PredictLRwithLBFGS.scala:

/**加载存储的模型,进行预测

  * Created by Administrator on 2018/8/24.

  */

object PredictLRwithLBFGS {

    val sparkConf = new SparkConf().setAppName("Shanghai traffic").setMaster("local[4]")

    val sc = new SparkContext(sparkConf)

    // create the date/time formatters

    val dayFormat = new SimpleDateFormat("yyyyMMdd")

    val minuteFormat = new SimpleDateFormat("HHmm")

    val sdf = new SimpleDateFormat( "yyyy-MM-dd_HH:mm:ss" );

    def main(args: Array[String]) {

        val input = "2017-04-20_09:50:00"

        val date = sdf.parse( input );

        val inputTimeLong = date.getTime()

        val inputTime = new Date(inputTimeLong)

        val day = dayFormat.format(inputTime)

        // fetch data from redis

        val jedis = RedisClient.pool.getResource

        jedis.select(1)

        // find relative road monitors for specified road

        // val camera_ids = List("310999003001","310999003102","310999000106","310999000205","310999007204")

        val camera_ids = List("310999003001","310999003102")

        val camera_relations:Map[String,Array[String]] = Map[String,Array[String]](

            "310999003001" -> Array("310999003001","310999003102","310999000106","310999000205","310999007204"),

            "310999003102" -> Array("310999003001","310999003102","310999000106","310999000205","310999007204")

        )

        val temp = camera_ids.map({ camera_id =>

            val list = camera_relations.get(camera_id).get

            val relations = list.map({ camera_id =>

                println(camera_id)

                // fetch records of one camera for three hours ago

                (camera_id, jedis.hgetAll(day + "_" + camera_id))

            })

            relations.foreach(println)

            // organize above records per minute to train data set format (MLUtils.loadLibSVMFile)

            val aaa = ArrayBuffer[Double]()

            // get current minute and recent two minutes

            for(index <- 0 to 2){

                val tempOne = inputTimeLong - 60 * 1000 * index

                val tempMinute = minuteFormat.format(inputTime)

                for((k,v) <- relations){

                    // k->camera_id ; v->speed

                    val map = v

                    if (map.containsKey(tempMinute)){

                        val info = map.get(tempMinute).split("_")

                        val f = info(0).toFloat / info(1).toFloat

                        aaa += f

                    } else{

                        aaa += -1.0

                    }

                }

            }

            // Run training algorithm to build the model

            //读取训练好的模型,进行预测

            val path = jedis.hget("model",camera_id)

            val model = LogisticRegressionModel.load(sc, path)

            // Compute raw scores on the test set.

            val prediction = model.predict(Vectors.dense(aaa.toArray))

            println(input+"\t"+camera_id+"\t"+prediction+"\t")

            jedis.hset(input, camera_id, prediction.toString)

        })

        RedisClient.pool.returnResource(jedis)

    }

 }

猜你喜欢

转载自blog.csdn.net/u011418530/article/details/82177764