今天学习了一个根据IP统计归属地的小案例,在此记录一下。
在电商网站后台都会记录用户的浏览日志,然后根据这些日志文件就可以做数据分析,比如统计用户的地址,喜好,这样就可以给用户推荐商品了。
那么怎样进行统计呢,首先我们要有一份各个省份的IP规则,然后要有一份日志文件,我们从日志文件中切分出IP字段,然后与IP规则进行对比,就可以匹配到是哪个地区的了。
我们先来写一下这个小案例的需求;
根据访问日志的IP地址计算出访问者的归属地,并且按照省份,计算出访问次数,然后将计算好的结果写入到MySQL中
1.整理数据,切分出IP字段,然后将IP地址转换成十进制
2.加载IP规则,整理规则,取出有用的字段,然后将数据缓存到内存中(Executor中的内存中)
3.将访问log与IP规则进行匹配(二分法查找)
4.取出对应的省份名称,然后将其和1组合在一起
5.按省份进行聚合
6.将聚合后的数据写入到MySQL中
难点:
首先我们要理解spark提交任务的机制以及RDD创建的机制,在此不在过多的阐述,可以查看博客:
RDD详解:https://blog.csdn.net/weixin_43866709/article/details/88623920
RDD之collect方法执行的过程:https://blog.csdn.net/weixin_43866709/article/details/88666080
难点就在于我们更好的使用IP规则这份数据
接下来我们一步一步的实现这个小需求
工具:spark集群,hdfs集群,MySQL,idea
1.加载IP规则,整理规则,取出有用的字段,然后将数据缓存到内存中(Executor中的内存中)
首先我们要将IP规则读取到hdfs中,这样可以保证IP规则这份数据不易丢失
val rulesLines: RDD[String] = sc.textFile(args(0))
然后整理IP规则,只取出有用的数据,比如用于比较的IP范围,还有对应的省份;
但是这里有一个问题,整理IP规则的是Task,是在Executor端执行的,这样每个Executor只是整理了部分的数据,后面得比较也是在Executor端执行的,这样会出现比较的错误。所以我们要将Executor处理完的IP规则收集到Driver端,这时Driver端的IP规则数据就是完整的了,再将Driver端的数据广播到Executor端,这样Executor端的数据就也是完整的了,就可以进行正确的比较了。
//整理ip规则数据
//这里是在Executor中执行的,每个Executor只计算部分的IP规则数据
val ipRulesRDD: RDD[(Long, Long, String)] = rulesLines.map(line => {
val fields = line.split("[|]")
val startNum = fields(2).toLong
val endNum = fields(3).toLong
val province = fields(6)
(startNum, endNum, province)
})
//需要将每个Executor端执行完的数据收集到Driver端
val rulesInDriver: Array[(Long, Long, String)] = ipRulesRDD.collect()
//再将Driver端的完整的数据广播到Executor端
//生成广播数据的引用
val broadcastRef: Broadcast[Array[(Long, Long, String)]] = sc.broadcast(rulesInDriver)
2.整理数据,切分出IP字段,然后将IP地址转换成十进制
3.将访问log与IP规则进行匹配(二分法查找)
4.取出对应的省份名称,然后将其和1组合在一起
首先我们先写一个小算法,用于将IP地址转换成十进制数字(这样更加便于比较)
TestIp.scala
//将IP转化为十进制
def ip2Long(ip: String): Long = {
val fragments = ip.split("[.]")
var ipNum = 0L
for (i <- 0 until fragments.length){
ipNum = fragments(i).toLong | ipNum << 8L
}
ipNum
}
再写一个小算法,用于IP地址的比较,因为IP规则是一个IP字段的范围,也就是说一个范围对应一个省份,要拿日志文件中的IP地址与这个范围进行比较,而且IP规则中的数据是排好序的,所以使用二分法查找会更加快捷:
TestIp.scala
//二分法查找
def binarySearch(lines: Array[(Long, Long, String)], ip: Long) : Int = {
var low = 0
var high = lines.length - 1
while (low <= high) {
val middle = (low + high) / 2
if ((ip >= lines(middle)._1) && (ip <= lines(middle)._2))
return middle
if (ip < lines(middle)._1)
high = middle - 1
else {
low = middle + 1
}
}
-1
}
然后我们开始整理日志文件的数据,取出IP地址,转换成十进制,然后与IP规则进行比较:
//整理日志文件的数据,取出ip,转换成十进制,与IP规则进行比较(采用二分法)
val provinceAndOne: RDD[(String, Int)] = accessLines.map(line => {
val fields = line.split("[|]")
val ip = fields(1)
//将ip转换成十进制
val ipNum = TestIp.ip2Long(ip)
//让Executor通过广播数据的引用拿到广播的数据
//Task是在Driver端生成的,广播变量的引用是伴随着Task被发送到Executor端的
val rulesInExecutor: Array[(Long, Long, String)] = broadcastRef.value
//查找
var province = "未知"
val index: Int = TestIp.binarySearch(rulesInExecutor, ipNum)
if (index != -1) {
province = rulesInExecutor(index)._3
}
(province, 1)
})
5.按省份进行聚合
val reduced: RDD[(String, Int)] = provinceAndOne.reduceByKey(_+_)
6.将聚合后的数据写入到MySQL中
我们也提前将写入MySQL的规则写好:
def data2MySQL(it: Iterator[(String, Int)]): Unit = {
//一个迭代器代表一个分区,分区中有多条数据
//先获得一个JDBC连接
val conn: Connection = DriverManager.getConnection("jdbc:mysql://localhost:3306/bigdata?characterEncoding=UTF-8", "用户", "密码")
//将数据通过Connection写入到数据库
val pstm: PreparedStatement = conn.prepareStatement("INSERT INTO access_log VALUES (?, ?)")
//将分区中的数据一条一条写入到MySQL中
it.foreach(tp => {
pstm.setString(1, tp._1)
pstm.setInt(2, tp._2)
pstm.executeUpdate()
})
//将分区中的数据全部写完之后,在关闭连接
if(pstm != null) {
pstm.close()
}
if (conn != null) {
conn.close()
}
}
在这里我们最好使用foreachPartition方法,一次拿出一个分区进行处理,这样一个分区使用一个jdbc连接,会更加节省资源。
reduced.foreachPartition(it => TestIp.data2MySQL(it))
到这里就处理完了,下面是完整的代码:
TestIp.scala
package XXX
import java.sql.{Connection, DriverManager, PreparedStatement}
import scala.io.{BufferedSource, Source}
object TestIp {
//将IP转化为十进制
def ip2Long(ip: String): Long = {
val fragments = ip.split("[.]")
var ipNum = 0L
for (i <- 0 until fragments.length){
ipNum = fragments(i).toLong | ipNum << 8L
}
ipNum
}
//定义读取ip.txt规则,只要有用的数据
def readRules(path:String):Array[(Long,Long,String)] = {
//读取ip.txt
val bf: BufferedSource = Source.fromFile(path)
//对ip.txt进行整理
val lines: Iterator[String] = bf.getLines()
//对ip进行整理,并放入内存
val rules: Array[(Long, Long, String)] = lines.map(line => {
val fileds = line.split("[|]")
val startNum = fileds(2).toLong
val endNum = fileds(3).toLong
val province = fileds(6)
(startNum, endNum, province)
}).toArray
rules
}
//二分法查找
def binarySearch(lines: Array[(Long, Long, String)], ip: Long) : Int = {
var low = 0
var high = lines.length - 1
while (low <= high) {
val middle = (low + high) / 2
if ((ip >= lines(middle)._1) && (ip <= lines(middle)._2))
return middle
if (ip < lines(middle)._1)
high = middle - 1
else {
low = middle + 1
}
}
-1
}
def data2MySQL(it: Iterator[(String, Int)]): Unit = {
//一个迭代器代表一个分区,分区中有多条数据
//先获得一个JDBC连接
val conn: Connection = DriverManager.getConnection("jdbc:mysql://localhost:3306/bigdata?characterEncoding=UTF-8", "用户", "密码")
//将数据通过Connection写入到数据库
val pstm: PreparedStatement = conn.prepareStatement("INSERT INTO access_log VALUES (?, ?)")
//将分区中的数据一条一条写入到MySQL中
it.foreach(tp => {
pstm.setString(1, tp._1)
pstm.setInt(2, tp._2)
pstm.executeUpdate()
})
//将分区中的数据全部写完之后,在关闭连接
if(pstm != null) {
pstm.close()
}
if (conn != null) {
conn.close()
}
}
def main(args: Array[String]): Unit = {
//数据是在内存中
val rules: Array[(Long, Long, String)] = readRules("E:/Spark视频/小牛学堂-大数据24期-06-Spark安装部署到高级-10天/spark-04-Spark案例讲解/课件与代码/ip/ip.txt")
//将ip地址转换成十进制
val ipNum = ip2Long("1.24.6.56")
//查找
val index = binarySearch(rules,ipNum)
//根据脚标到rules中查找对应的数据
val tp = rules(index)
val province = tp._3
println(province)
}
}
IpLocation.scala
package XXXX
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}
object IpLocation2 {
def main(args: Array[String]): Unit = {
val conf = new SparkConf().setAppName("IpLocation2").setMaster("local[4]")
val sc = new SparkContext(conf)
//将ip.txt读取到HDFS中
val rulesLines: RDD[String] = sc.textFile(args(0))
//整理ip规则数据
//这里是在Executor中执行的,每个Executor只计算部分的IP规则数据
val ipRulesRDD: RDD[(Long, Long, String)] = rulesLines.map(line => {
val fields = line.split("[|]")
val startNum = fields(2).toLong
val endNum = fields(3).toLong
val province = fields(6)
(startNum, endNum, province)
})
//需要将每个Executor端执行完的数据收集到Driver端
val rulesInDriver: Array[(Long, Long, String)] = ipRulesRDD.collect()
//再将Driver端的完整的数据广播到Executor端
//生成广播数据的引用
val broadcastRef: Broadcast[Array[(Long, Long, String)]] = sc.broadcast(rulesInDriver)
//接下来开始读取访问日志数据
val accessLines: RDD[String] = sc.textFile(args(1))
//整理日志文件的数据,取出ip,转换成十进制,与IP规则进行比较(采用二分法)
val provinceAndOne: RDD[(String, Int)] = accessLines.map(line => {
val fields = line.split("[|]")
val ip = fields(1)
//将ip转换成十进制
val ipNum = TestIp.ip2Long(ip)
//让Executor通过广播数据的引用拿到广播的数据
//Task是在Driver端生成的,广播变量的引用是伴随着Task被发送到Executor端的
val rulesInExecutor: Array[(Long, Long, String)] = broadcastRef.value
//查找
var province = "未知"
val index: Int = TestIp.binarySearch(rulesInExecutor, ipNum)
if (index != -1) {
province = rulesInExecutor(index)._3
}
(province, 1)
})
//聚合
val reduced: RDD[(String, Int)] = provinceAndOne.reduceByKey(_+_)
reduced.foreachPartition(it => TestIp.data2MySQL(it))
//释放资源
sc.stop()
}
}
//这种方法是通过HDFS读取IP规则(ip.txt),在收集到Driver端,然后再广播到Executor端
//优点:IP规则更加安全,不容易丢失,而且不用和Driver在同一台机器