在日常生活中,包括在设计计算机软件时,我们经常要判断一个元素是否在一个 集合中。比如在字处理软件中,需要检查一个英语单词是否拼写正确(也就是要判断它是否在已知的字典中);在 FBI,一个嫌疑人的名字是否已经在嫌疑名单上;在网络爬虫里,一个网址是否被访问过等等。最直接的方法就是将集合中全部的元素存在计算机中,遇到一个新 元素时,将它和集合中的元素直接比较即可。一般来讲,计算机中的集合是用哈希表(hash table)来存储的。它的好处是快速准确,缺点是费存储空间。当集合比较小时,这个问题不显著,但是当集合巨大时,哈希表存储效率低的问题就显现出来 了。比如说,一个象 Yahoo,Hotmail 和 Gmai 那样的公众电子邮件(email)提供商,总是需要过滤来自发送垃圾邮件的人(spamer)的垃圾邮件。一个办法就是记录下那些发垃圾邮件的 email 地址。由于那些发送者不停地在注册新的地址,全世界少说也有几十亿个发垃圾邮件的地址,将他们都存起来则需要大量的网络服务器。如果用哈希表,每存储一亿 个 email 地址, 就需要 1.6GB 的内存(用哈希表实现的具体办法是将每一个 email 地址对应成一个八字节的信息指纹googlechinablog.com/2006/08/blog-post.html , 然后将这些信息指纹存入哈希表,由于哈希表的存储效率一般只有 50%,因此一个 email 地址需要占用十六个字节。一亿个地址大约要 1.6GB, 即十六亿字节的内存)。因此存贮几十亿个邮件地址可能需要上百 GB 的内存。除非是超级计算机,一般服务器是无法存储的。
今天,我 们介绍一种称作布隆过滤器的数学工具,它只需要哈希表 1/8 到 1/4 的大小就能解决同样的问题。
布隆过滤器是由巴顿.布隆于一九 七零年提出的。它实际上是一个很长的二进制向量和一系列随机映射函数。我们通过上面的例子来说明起工作原理。
假定我们存储一亿个电子邮件 地址,我们先建立一个十六亿二进制(比特),即两亿字节的向量,然后将这十六亿个二进制全部设置为零。对于每一个电子邮件地址 X,我们用八个不同的随机数产生器(F1,F2, ...,F8) 产生八个信息指纹(f1, f2, ..., f8)。再用一个随机数产生器 G 把这八个信息指纹映射到 1 到十六亿中的八个自然数 g1, g2, ...,g8。现在我们把这八个位置的二进制全部设置为一。当我们对这一亿个 email 地址都进行这样的处理后。一个针对这些 email 地址的布隆过滤器就建成了。(见下图)
现在,让我们看看如何用布隆过滤器来检测一个可疑的电子邮件地址 Y 是否在黑名单中。我们用相同的八个随机数产生器(F1, F2, ..., F8)对这个地址产生八个信息指纹 s1,s2,...,s8,然后将这八个指纹对应到布隆过滤器的八个二进制位,分别是 t1,t2,...,t8。如果 Y 在黑名单中,显然,t1,t2,..,t8 对应的八个二进制一定是一。这样在遇到任何在黑名单中的电子邮件地址,我们都能准确地发现。
布隆过滤器决不会漏掉任何一个在黑名单中的可疑地址。但是,它有一条不足之处。也就是它有极小的可能将一个不在黑名单中的电子邮件地址判定为在黑名单中, 因为有可能某个好的邮件地址正巧对应个八个都被设置成一的二进制位。好在这种可能性很小。我们把它称为误识概率。在上面的例子中,误识概率在万分之一以 下。
布隆过滤器的好处在于快速,省空间。但是有一定的误识别率。常见的补救办法是在建立一个小的白名单,存储那些可能别误判的邮件地址。
java版本
import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.io.Serializable; import java.util.BitSet; import java.util.concurrent.atomic.AtomicInteger; public class BloomFileter implements Serializable { private static final long serialVersionUID = -5221305273707291280L; private final int[] seeds; private final int size; private final BitSet notebook; private final MisjudgmentRate rate; private final AtomicInteger useCount = new AtomicInteger(0); private final Double autoClearRate; /** * 默认中等程序的误判率:MisjudgmentRate.MIDDLE 以及不自动清空数据(性能会有少许提升) * * @param dataCount * 预期处理的数据规模,如预期用于处理1百万数据的查重,这里则填写1000000 */ public BloomFileter(int dataCount) { this(MisjudgmentRate.MIDDLE, dataCount, null); } /** * * @param rate * 一个枚举类型的误判率 * @param dataCount * 预期处理的数据规模,如预期用于处理1百万数据的查重,这里则填写1000000 * @param autoClearRate * 自动清空过滤器内部信息的使用比率,传null则表示不会自动清理, * 当过滤器使用率达到100%时,则无论传入什么数据,都会认为在数据已经存在了 * 当希望过滤器使用率达到80%时自动清空重新使用,则传入0.8 */ public BloomFileter(MisjudgmentRate rate, int dataCount, Double autoClearRate) { long bitSize = rate.seeds.length * dataCount; if (bitSize < 0 || bitSize > Integer.MAX_VALUE) { throw new RuntimeException("位数太大溢出了,请降低误判率或者降低数据大小"); } this.rate = rate; seeds = rate.seeds; size = (int) bitSize; notebook = new BitSet(size); this.autoClearRate = autoClearRate; } public void add(String data) { checkNeedClear(); for (int i = 0; i < seeds.length; i++) { int index = hash(data, seeds[i]); setTrue(index); } } public boolean check(String data) { for (int i = 0; i < seeds.length; i++) { int index = hash(data, seeds[i]); if (!notebook.get(index)) { return false; } } return true; } /** * 如果不存在就进行记录并返回false,如果存在了就返回true * * @param data * @return */ public boolean addIfNotExist(String data) { checkNeedClear(); int[] indexs = new int[seeds.length]; // 先假定存在 boolean exist = true; int index; for (int i = 0; i < seeds.length; i++) { indexs[i] = index = hash(data, seeds[i]); if (exist) { if (!notebook.get(index)) { // 只要有一个不存在,就可以认为整个字符串都是第一次出现的 exist = false; // 补充之前的信息 for (int j = 0; j <= i; j++) { setTrue(indexs[j]); } } } else { setTrue(index); } } return exist; } private void checkNeedClear() { if (autoClearRate != null) { if (getUseRate() >= autoClearRate) { synchronized (this) { if (getUseRate() >= autoClearRate) { notebook.clear(); useCount.set(0); } } } } } public void setTrue(int index) { useCount.incrementAndGet(); notebook.set(index, true); } private int hash(String data, int seeds) { char[] value = data.toCharArray(); int hash = 0; if (value.length > 0) { for (int i = 0; i < value.length; i++) { hash = i * hash + value[i]; } } hash = hash * seeds % size; // 防止溢出变成负数 return Math.abs(hash); } public double getUseRate() { return (double) useCount.intValue() / (double) size; } public void saveFilterToFile(String path) { try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(path))) { oos.writeObject(this); } catch (Exception e) { throw new RuntimeException(e); } } public static BloomFileter readFilterFromFile(String path) { try (ObjectInputStream ois = new ObjectInputStream(new FileInputStream(path))) { return (BloomFileter) ois.readObject(); } catch (Exception e) { throw new RuntimeException(e); } } /** * 清空过滤器中的记录信息 */ public void clear() { useCount.set(0); notebook.clear(); } public MisjudgmentRate getRate() { return rate; } /** * 分配的位数越多,误判率越低但是越占内存 * * 4个位误判率大概是0.14689159766308 * * 8个位误判率大概是0.02157714146322 * * 16个位误判率大概是0.00046557303372 * * 32个位误判率大概是0.00000021167340 * * @author lianghaohui * */ public enum MisjudgmentRate { // 这里要选取质数,能很好的降低错误率 /** * 每个字符串分配4个位 */ VERY_SMALL(new int[] { 2, 3, 5, 7 }), /** * 每个字符串分配8个位 */ SMALL(new int[] { 2, 3, 5, 7, 11, 13, 17, 19 }), // /** * 每个字符串分配16个位 */ MIDDLE(new int[] { 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53 }), // /** * 每个字符串分配32个位 */ HIGH(new int[] { 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131 }); private int[] seeds; private MisjudgmentRate(int[] seeds) { this.seeds = seeds; } public int[] getSeeds() { return seeds; } public void setSeeds(int[] seeds) { this.seeds = seeds; } } public static void main(String[] args) { BloomFileter fileter = new BloomFileter(7); System.out.println(fileter.addIfNotExist("1111111111111")); System.out.println(fileter.addIfNotExist("2222222222222222")); System.out.println(fileter.addIfNotExist("3333333333333333")); System.out.println(fileter.addIfNotExist("444444444444444")); System.out.println(fileter.addIfNotExist("5555555555555")); System.out.println(fileter.addIfNotExist("6666666666666")); System.out.println(fileter.addIfNotExist("1111111111111")); fileter.saveFilterToFile("C:\\Users\\john\\Desktop\\1111\\11.obj"); fileter = readFilterFromFile("C:\\Users\\john\\Desktop\\111\\11.obj"); System.out.println(fileter.getUseRate()); System.out.println(fileter.addIfNotExist("1111111111111")); } }
python版本
# -*- coding: utf-8 -*- import cmath from BitVector import BitVector ''' 布隆过滤器实现 ''' class BloomFilter(object): def __init__(self, error_rate, elementNum): # 计算所需要的bit数 self.bit_num = -1 * elementNum * cmath.log(error_rate) / (cmath.log(2.0) * cmath.log(2.0)) # 四字节对齐 self.bit_num = self.align_4byte(self.bit_num.real) # 分配内存 self.bit_array = BitVector(size=self.bit_num) # 计算hash函数个数 self.hash_num = cmath.log(2) * self.bit_num / elementNum self.hash_num = self.hash_num.real # 向上取整 self.hash_num = int(self.hash_num) + 1 # 产生hash函数种子 self.hash_seeds = self.generate_hashseeds(self.hash_num) def insert_element(self, element): for seed in self.hash_seeds: hash_val = self.hash_element(element, seed) # 取绝对值 hash_val = abs(hash_val) # 取模,防越界 hash_val = hash_val % self.bit_num # 设置相应的比特位 self.bit_array[hash_val] = 1 # 检查元素是否存在,存在返回true,否则返回false def is_element_exist(self, element): for seed in self.hash_seeds: hash_val = self.hash_element(element, seed) # 取绝对值 hash_val = abs(hash_val) # 取模,防越界 hash_val = hash_val % self.bit_num # 查看值 if self.bit_array[hash_val] == 0: return False return True # 内存对齐 def align_4byte(self, bit_num): num = int(bit_num / 32) num = 32 * (num + 1) return num # 产生hash函数种子,hash_num个素数 def generate_hashseeds(self, hash_num): count = 0 # 连续两个种子的最小差值 gap = 50 # 初始化hash种子为0 hash_seeds = [] for index in range(0, hash_num): hash_seeds.append(0) for index in range(10, 10000): max_num = int(cmath.sqrt(1.0 * index).real) flag = 1 for num in range(2, max_num): if index % num == 0: flag = 0 break if flag == 1: # 连续两个hash种子的差值要大才行 if count > 0 and (index - hash_seeds[count - 1]) < gap: continue hash_seeds[count] = index count = count + 1 if count == hash_num: break return hash_seeds def hash_element(self, element, seed): hash_val = 1 for ch in str(element): chval = ord(ch) hash_val = hash_val * seed + chval return hash_val