Spark- 根据ip地址计算归属地

主要考察的是广播变量的使用:

1、将要广播的数据 IP 规则数据存放在HDFS上,(广播出去的内容一旦广播出去产就不能改变了,如果需要实时改变的规则,可以将规则放到Redis中)

2、在Spark中转成RDD,然后收集到Driver端,

3、把 IP 规则数据广播到Executor中。Driver端广播变量的引用是怎样跑到 Executor中的呢?  Task在Driver端生成的,广播变量的引用是伴随着Task被发送到Executor中的,广播变量的引用也被发送到Executor中,恰好指向HDFS

4、Executor执行分配到的 Task时,从Executor中获取 IP 规则数据做计算。

package com.rz.spark.base

import java.sql.{Connection, DriverManager, PreparedStatement}

import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}

object IpLocation2 {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName(this.getClass.getSimpleName).setMaster("local[2]")
    val sc = new SparkContext(conf)

    // 取到HDFS中的 ip规则
    val rulesLine: RDD[String] = sc.textFile(args(0))

    // 整理ip规则数据
    val ipRulesRDD: RDD[(Long, Long, String)] = rulesLine.map(line => {
      val fields = line.split("[|]")
      val startNum = fields(2).toLong
      val endNum = fields(3).toLong
      val province = fields(6)
      (startNum, endNum, province)
    })
    // 将分散在多个Executor中的部分IP规则数据收集到Driver端
    val rulesInDriver: Array[(Long, Long, String)] = ipRulesRDD.collect()
    
    // 将Driver端的数据广播到Executor中
    // 调用sc上的广播方法
    // 广播变量的引用(还在Driver端中)
    val broadcastRef: Broadcast[Array[(Long, Long, String)]] = sc.broadcast(rulesInDriver)

    // 创建RDD,读取访问日志
    val accessLines: RDD[String] = sc.textFile(args(1))

    // 整理数据
    val provinceAndOne: RDD[(String, Int)] = accessLines.map(log => {
      // 将log日志的第一行进行切分
      val fields = log.split("[|]")
      val ip = fields(1)
      // 将ip转换成10进制
      val ipNum = MyUtils.ip2Long(ip)
      // 进行二分法查找,通过Driver端的引用获取到Executor中的广播变量
      // (该函数中的代码是在Executor中被调用执行的,通过广播变量的引用,就可以拿到当前Executor中的广播的ip二人规则)
      // Driver端广播变量的引用是怎样跑到 Executor中的呢?
      // Task在Driver端生成的,广播变量的引用是伴随着Task被发送到Executor中的,广播变量的引用也被发送到Executor中,恰好指向HDFS
      val rulesInExecutor: Array[(Long, Long, String)] = broadcastRef.value
      // 查找
      var province = "末知"
      val index = MyUtils.binarySearch(rulesInExecutor, ipNum)
      if (index != -1) {
        province = rulesInExecutor(index)._3
      }
      (province, 1)
    })
    // 聚合
    val reduced: RDD[(String, Int)] = provinceAndOne.reduceByKey(_+_)
    // 将结果打印
//    val result = reduced.collect()
//    println(result.toBuffer)

    // 将结果写入到MySQL中
    // 一次拿一个分区的每一条数据
    reduced.foreachPartition(it=>{
      val conn: Connection = DriverManager.getConnection("jdbc:mysql://localhost:3306/bigdata?characterEncoding=utf-8","root","root")
      val pstm: PreparedStatement = conn.prepareStatement("insert into access_log values(?,?)")

      it.foreach(tp=>{
        pstm.setString(1, tp._1)
        pstm.setInt(2,tp._2)
        pstm.executeUpdate()
      })
      pstm.close()
      conn.close()
    })

    sc.stop()
  }
}

工具类

package com.rz.spark.base

import java.sql
import java.sql.{DriverManager, PreparedStatement}

import scala.io.{BufferedSource, Source}

object MyUtils {

  def ip2Long(ip: String): Long = {
    val fragments = ip.split("[.]")
    var ipNum = 0L
    for (i <- 0 until fragments.length){
      ipNum =  fragments(i).toLong | ipNum << 8L
    }
    ipNum
  }

  def readRules(path: String): Array[(Long, Long, String)] = {
    //读取ip规则
    val bf: BufferedSource = Source.fromFile(path)
    val lines: Iterator[String] = bf.getLines()
    //对ip规则进行整理,并放入到内存
    val rules: Array[(Long, Long, String)] = lines.map(line => {
      val fileds = line.split("[|]")
      val startNum = fileds(2).toLong
      val endNum = fileds(3).toLong
      val province = fileds(6)
      (startNum, endNum, province)
    }).toArray
    rules
  }

  def binarySearch(lines: Array[(Long, Long, String)], ip: Long) : Int = {
    var low = 0
    var high = lines.length - 1
    while (low <= high) {
      val middle = (low + high) / 2
      if ((ip >= lines(middle)._1) && (ip <= lines(middle)._2))
        return middle
      if (ip < lines(middle)._1)
        high = middle - 1
      else {
        low = middle + 1
      }
    }
    -1
  }

  def data2MySQL(it: Iterator[(String, Int)]): Unit = {
    //一个迭代器代表一个分区,分区中有多条数据
    //先获得一个JDBC连接
    val conn: sql.Connection = DriverManager.getConnection("jdbc:mysql://localhost:3306/bigdata?characterEncoding=UTF-8", "root", "123568")
    //将数据通过Connection写入到数据库
    val pstm: PreparedStatement = conn.prepareStatement("INSERT INTO access_log VALUES (?, ?)")
    //将分区中的数据一条一条写入到MySQL中
    it.foreach(tp => {
      pstm.setString(1, tp._1)
      pstm.setInt(2, tp._2)
      pstm.executeUpdate()
    })
    //将分区中的数据全部写完之后,在关闭连接
    if(pstm != null) {
      pstm.close()
    }
    if (conn != null) {
      conn.close()
    }
  }
}

猜你喜欢

转载自www.cnblogs.com/RzCong/p/10660563.html