抽样随机算法

水塘抽样/蓄水池抽样算法(Reservoir Sampling)

问题
对一个数量未知的样本,希望只经过一次遍历就能完成随机抽样,即时间复杂度O(n),每个元素被选中的概率一样。因为样本数量未知,是不能通过random函数直接随机抽样的。

解法

先选择第一个对象,以1/2的概率选择第二个,以1/3的概率选择第三个,以此类推,以1/m的概率选择第m个对象。当该过程结束时,每一个对象具有相同的选中概率,即1/n

论证
第m个对象最终被选中的概率P=前m个元素中选择m的概率*其后面所有元素不被选择的概率,即

  • 选择m的概率 : 1/m: 前m个元素选中它的概率就是我们规定的1/m

  • m后某个对象不被选择的概率为:假设为 m + k个元素, 那么概率为1 - 1 / (m + k) :因为m+k个元素每个元素选中的概率为 1/(m+k)(同m个元素一样), 所以不选中概率就是1 - 1 / (m + k) =m+k-1 / m +k

最后列出如下公式,可以看到最后选中m概率是1/n
在这里插入图片描述


例:N个元素随机选择k个元素

伪代码如下:

array S[N]; // 庞大的样本
array R[k]; // 水库
 
for(int i = 0; i < k; ++i){
    
    
	R[i] = S[i];
}

for(int i = k; i  < N; ++i){
    
    
    p = random(i); //[0,i] 随机一个数,1/i的概率
    if(p < k){
    
    
    	R[p] = S[i];
    }    
}

这样每个元素被选择的概率为k/N

解释如下

对于第k+1个元素:k+1随机概率小于k, 就是k/k+1, 即进入if语句的概率是k/k+1, 那么它被替换成的概率就是k/k+1, 同时前面k个元素,每个元素都是1/k+1的概率成为被替换对象,那么它不被替换的概率是1 - 1/k+1=k/k+1, 不被替换最终就出现, 即旧元素最终出现概率是k/k+1

而新元素在水库中出现的概率就一定是k/k+1(不管它替换掉前k个的哪个元素,反正肯定它是以这个概率出现在水库中的)

那么即论证了到k+1个元素,每个元素留在水库的概率是k/k+1,依此类推,k+2个元素是k/k+2, … , 到第N个元素则是k/N


例题: https://leetcode-cn.com/problems/linked-list-random-node/

class Solution {
    
    

    ListNode head;
    Random random;

    public Solution(ListNode head) {
    
    
        this.head = head;
        random = new Random();
    }

    public int getRandom() {
    
    
        // 第一元素选中
        int k = 1;
        ListNode node = head;

        // 第2个元素开始,1/i的概率
        int len = 2;
        ListNode s = head.next;
        while (s != null){
    
    
            // [0, len)
            int p = random.nextInt(len);
            if (p < k) {
    
    
                node = s;
            }

            len++;
            s = s.next;
        }

        return node.val;
    }
}

拒绝采样(reject sampling)

例题:https://leetcode-cn.com/problems/implement-rand10-using-rand7/

转自leetcode官方题解
在这里插入图片描述
作者:LeetCode-Solution
链接:https://leetcode-cn.com/problems/implement-rand10-using-rand7/solution/yong-rand7-shi-xian-rand10-by-leetcode-s-qbmd/
来源:力扣(LeetCode)
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。

如上图,使用 ( rand7() - 1 ) * 7 + rand7() 生成随机数,可以生成49个数,但是只取前40个数,剩下的拒绝采样; 而采样的40个数每个概率都是1/49, 转化为1-10数字出现的概率则都是4/49,所以能等概率产生1-10,产生的数转换下即可

idx 序号如下,转换成1-10: 1 + (idx - 1) % 10

即:
11,21,31转换成1; 12,22,32 转换成2 ; 依此类推…
在这里插入图片描述

/**
 * The rand7() API is already defined in the parent class SolBase.
 * public int rand7();
 * @return a random integer in the range 1 to 7
 */
class Solution extends SolBase {
    
    
    public int rand10() {
    
    
        int idx = 41;
        while(idx > 40){
    
    
            idx = (rand7() - 1) * 7 + rand7();
        }
        return (idx - 1) % 10 + 1;
    }
}

问题:

  • 时间复杂度:期望时间复杂度为 O(1),但最坏情况下会达到 O(∞)(一直被拒绝)。
  • 空间复杂度:O(1)。

Todo 可以得到规律

(rand_X() - 1) × Y + rand_Y()  可以等概率的生成[1, X * Y]范围的随机数

随机权重选择问题

https://leetcode-cn.com/problems/random-pick-with-weight/

给你一个 下标从 0 开始 的正整数数组 w ,其中 w[i] 代表第 i 个下标的权重。

请你实现一个函数 pickIndex ,它可以 随机地 从范围 [0, w.length - 1] 内(含 0 和 w.length - 1)选出并返回一个下标。选取下标 i 的 概率 为 w[i] / sum(w) 。

例如,对于 w = [1, 3],挑选下标 0 的概率为 1 / (1 + 3) = 0.25 (即,25%),而选取下标 1 的概率为 3 / (1 + 3) = 0.75(即,75%)。


来源:力扣(LeetCode)
链接:https://leetcode-cn.com/problems/random-pick-with-weight
著作权归领扣网络所有。商业转载请联系官方授权,非商业转载请注明出处。

采用把区间放大的思路,即某个数权重大,给的区间也大,同样的概率分给该数个概率也就增大,完成按权重分概率

class Solution {
    
    
    int len;
    int[] pre;
    int total;
    
    public Solution(int[] w) {
    
    
        len = w.length;
        pre = new int[len];
        pre[0] = w[0];
        int sums = w[0];
        for (int i = 1; i < len; i ++) {
    
    
            pre[i] = pre[i - 1] + w[i];
            sums += w[i];
        }
        total = sums;
    }
    
    public int pickIndex() {
    
    
        // [1, total] 之间的随机数
        int x = (int) (Math.random() * total) + 1;
        // 找在哪个区间
        return binarySearch(x);
    }

    private int binarySearch(int x) {
    
    
        int low = 0;
        int high = len - 1;
        while (low < high) {
    
    
            int mid = (high - low) / 2 + low;
            if (pre[mid] < x) {
    
    
                low = mid + 1;
            } else {
    
    
                high = mid;
            }
        }
        return low;
    }
}

/**
 * Your Solution object will be instantiated and called as such:
 * Solution obj = new Solution(w);
 * int param_1 = obj.pickIndex();
 */

猜你喜欢

转载自blog.csdn.net/qq_26437925/article/details/124415531
今日推荐