[HAOI2015]按位或(min-max容斥,FWT,FMT)

题目链接:洛谷

题目大意:给定正整数 $n$。一开始有一个数字 $0$,然后每一秒,都有 $p_i$ 的概率获得 $i$ 这个数 $(0\le i< 2^n)$。一秒恰好会获得一个数。每获得一个数,就要将我们有的数与获得的数进行按位或。问期望经过多少秒后,我们的数变成 $2^n-1$。

$1\le n\le 20,\sum p_i=1$。


%%%stO shadowice1984 Orz%%%

首先定义 $\min(S)$ 表示 $S$ 中第一个变为 $1$ 的元素的时间。(其中 $S$ 是一个二进制数,是 $1$ 的位表示这一位计入答案)

$\max(S)$ 表示最后一个变为 $1$ 的元素的时间。这也符合 $\min$ 和 $\max$ 的定义。

(以下定义全集 $U=2^n-1$)

我们要求的就是 $E(\max(U))$。

容斥:$E(\max(U))=\sum\limits_{S\neq\varnothing}(-1)^{|S|+1}E(\min(S))$

现在问题就是求 $E(\min(S))$,就是有元素变为 $1$。

套期望的公式:$E(\min(S))=\sum\limits^{+\infty}_{i=1}iP(\min(S)=i)$

$P(\min(S)=i)$ 表示恰好在第 $i$ 秒出现 $1$ 的概率。

前 $i-1$ 秒都没有出现,所以应该是 $\sum\limits_{T\cap S=\varnothing}P(T)$。其中 $P(T)$ 表示出现的数是 $T$ 的概率。

第 $i$ 秒有出现,所以应该是 $1-\sum\limits_{T\cap S=\varnothing}P(T)$。

乘法原理,$P(\min(S)=i)=(\sum\limits_{T\cap S=\varnothing}P(T))^{i-1}(1-\sum\limits_{T\cap S=\varnothing}P(T))$

那么经过一通爆算,$E(\min(S))=\dfrac{1}{1-\sum\limits_{T\cap S=\varnothing}P(T)}$。

转换一下:$E(\min(S))=\dfrac{1}{1-\sum\limits_{T\subseteq\complement_US}P(T)}$。

现在最严峻的问题就是计算每个集合的子集和。

最裸的枚举,$O(4^n)$。

技巧一点的枚举,$O(3^n)$。

那么就要说到一个很有趣的事情了,我也是做了这题才知道的……

回想一下FWT(或者FMT)做按位或卷积的时候:

$C_i=\sum\limits_{j|k=i}A_jB_k$

令 $\widehat{C}_i=\sum\limits_{j\subseteq i}C_j$

那么就有 $\widehat{C}_i=\sum\limits_{j|k\subseteq i}A_jB_k$

也就是 $\widehat{C}_i=\sum\limits_{j\subseteq i,k\subseteq i}A_jB_k$

也就是 $\widehat{C}_i=\widehat{A}_i\widehat{B}_i$!

于是需要一种从 $A$ 到 $\widehat{A}$ 的变换还有它的逆变换。

于是就有了FWT(或者FMT)……

(所以说,FWT比FFT要好理解,那就要理解啊)

那么我们就知道了,FWT的或变换一次后新的序列就是原序列的子集和形式!

好的,时间复杂度 $O(n2^n)$。

代码:(实际上特别好写)

#include<bits/stdc++.h>
using namespace std;
const int maxn=1111111;
#define FOR(i,a,b) for(int i=(a);i<=(b);i++)
#define ROF(i,a,b) for(int i=(a);i>=(b);i--)
#define MEM(x,v) memset(x,v,sizeof(x))
inline int read(){
    char ch=getchar();int x=0,f=0;
    while(ch<'0' || ch>'9') f|=ch=='-',ch=getchar();
    while(ch>='0' && ch<='9') x=x*10+ch-'0',ch=getchar();
    return f?-x:x;
}
int n,cnt[maxn];
double p[maxn];
int main(){
    n=read();
    FOR(i,0,(1<<n)-1) scanf("%lf",p+i);
    for(int i=1;i<1<<n;i<<=1)
        for(int j=0,r=i<<1;j<1<<n;j+=r)
            FOR(k,0,i-1) p[i+j+k]+=p[j+k];
    FOR(i,1,(1<<n)-1) cnt[i]=cnt[i>>1]+(i&1);
    double ans=0;
    FOR(i,1,(1<<n)-1){
        if(1-p[((1<<n)-1)^i]<1e-8) return puts("INF"),0;
        ans+=1/(1-p[((1<<n)-1)^i])*(cnt[i]&1?1:-1);
    }
    printf("%.10lf\n",ans);
}
View Code

猜你喜欢

转载自www.cnblogs.com/1000Suns/p/10388001.html