FWT学习小结

引入

有的时候我们需要进行这样的求和:
x y a [ x ] b [ y ] \sum_{x \otimes y} a[x]\cdot b[y]
其中 \otimes 为二元运算 a n d , o r , xor and ,or,\text{xor} 之一,即位运算卷积.

暴力显然是 O ( n 2 ) O(n^2) ,我们可不可以用类似 F F T FFT 的思想,把 a , b a,b 转化为 f w t [ a ] , f w t [ b ] fwt[a],fwt[b] (转点值),然后令
f w t [ c ] = f w t [ a ] f w t [ b ] fwt[c]=fwt[a]\cdot fwt[b] ,然后再对 f w t [ c ] fwt[c] 进行求逆呢?(由点值求系数) 答案是肯定的!

这样的正逆变换 称为 快速沃尔什变换.

o r or

f w t [ a ] [ i ] = i j = i a [ j ] fwt[a][i]=\sum_{i|j=i} a[j] .
我们把每个二进制看做一维的话,就是一个高维前缀和啦~~

a 0 , a 1 a_0,a_1 表示 a a 前后长度为 n / 2 n/2 的系数子序列,令 a 0 + a 1 a_0+a_1 表示对应位置相加, merge \text{merge} 表示序列相接,则有.

f w t [ a ] = merge ( f w t [ a 0 ] , f w t [ a 0 + a 1 ] ) fwt[a]=\text{merge}(fwt[a_0],fwt[a_0+a_1]) .

void fwt_or(int *f) {
    for(int k=1;k<n;k*=2)//维度
        for(int i=0;i<n;i+=2*k)
            for(int j=0;j<k;j++)
                (f[i+j+k] += f[i+j]) %= mod;
}

现在需要证明的是
f w t [ c ] [ i ] = f w t [ a ] [ i ] f w t [ b ] [ i ] fwt[c][i]=fwt[a][i]\cdot fwt[b][i]
f w t [ a ] [ i ] f w t [ b ] [ i ] = ( i j = i a [ j ] ) ( k i = i b [ k ] ) fwt[a][i]\cdot fwt[b][i]=\left(\sum_{i|j=i} a[j] \right) \cdot \left( \sum_{k|i=i} b[k] \right)
i j = i , k i = i ( j k ) i = i 因为i|j=i,k|i=i\rightarrow (j|k)|i=i
f w t [ a ] [ i ] f w t [ b ] [ i ] = ( j k ) i a [ j ] b [ k ] = j i = i c [ j ] = f w t [ c ] [ i ] fwt[a][i]\cdot fwt[b][i]=\sum_{(j|k)|i} a[j] b[k]=\sum_{j|i=i}c[j]=fwt[c][i]

逆变换的时候,只要把正变换的影响消去即可.
a = I F W T ( f w t [ a ] ) = merge ( I F W T ( f w t [ a 0 ] ) , I F W T ( f w t [ a 1 ] f w t [ a 0 ] ) ) a=IFWT(fwt[a])=\text{merge}(IFWT(fwt[a_0]),IFWT(fwt[a_1]-fwt[a_0])) .

void ifwt_or(int *f) {
    for(int k=n/2;k;k/=2)
        for(int i=0;i<n;i+=2*k)
            for(int j=0;j<k;j++)
                f[i+j+k] = (f[i+j+k]-f[i+j]+mod)%mod;
}

你可能觉得这样的话,上面的 k k 必须从 n / 2 n/2 开始 f o r for ,其实从 k = 1 k=1 开始结果是一样的.
因为你把高低位互换不影响变换的正确性,这个东西在后面都适用,所以两个代码可以合并.

void fwt_or(int *f,ll x) {
    if(x==-1) x+=mod;
    for(int k=1;k<n;k*=2)
        for(int i=0;i<n;i+=2*k)
            for(int j=0;j<k;j++)
                add(f[i+j+k],f[i+j]*x%mod); 
}

a n d and

同理,设 f w t [ a ] [ i ] = j & i = i a [ j ] fwt[a][i]=\sum_{j\&i=i} a[j] .
因为 j & i = i , k & i = i ( j & k ) & i = i j\& i=i,k\& i=i\rightarrow (j\&k) \& i=i ,所以同理可证转点值后相乘的结果正确.

正逆变化类似.

void fwt_and(int *f,ll x) {
    if(x==-1) x+=mod;
    for(int k=1;k<n;k*=2)
        for(int i=0;i<n;i+=2*k)
            for(int j=0;j<k;j++) 
                add(f[i+j],f[i+j+k]*x%mod);
}

xor \text{xor}

c n t ( i ) cnt(i) 表示 i i 二进制下有多少个1.
定义 x y = c n t ( x & y ) m o d    2 x\otimes y=cnt(x\&y)\mod 2 .
f w t [ a ] [ i ] = i j = 0 a [ j ] i j = 1 a [ j ] = a [ j ] ( 1 ) i j fwt[a][i]=\sum_{i \otimes j=0} a[j]-\sum_{i\otimes j=1} a[j]=\sum a[j]*(-1)^{i\otimes j} .

性质: ( i j )   xor   ( j k ) = i ( j   xor   k ) (i\otimes j) ~~\text{xor} ~~ (j\otimes k)=i\otimes(j~~\text{xor}~~k) .
证明: xor \text{xor} 为不进位加法,所以我们实际上是证明 c n t ( i & j ) + c n t ( i & k ) c n t ( i & ( j xor k ) ) ( m o d    2 ) cnt(i\&j)+cnt(i\&k)\equiv cnt(i\&(j\text{xor} k))(\mod 2)

我们对一个位的所有情况进行证明,那么总体就一定满足.

i i j j k k
0 - -
1 0 1
1 1 1

i = 0 i=0 显然都是0.
i = 1 , j + k = 1 i=1,j+k=1 ,则显然成立.
i = 1 , j + k = 2 i=1,j+k=2 ,左边为2,右边为0,成立!

综上:我们通过穷举证明了每一位的情况,也就是证明了所有的情况都满足.

现在依然是要证明:
f w t [ c ] [ i ] = f w t [ a ] [ i ] f w t [ b ] [ i ] fwt[c][i]=fwt[a][i]\cdot fwt[b][i]
f w t [ a ] [ i ] f w t [ b ] [ i ] = ( ( 1 ) i j a [ j ] ) ( ( 1 ) i k b [ k ] ) = ( 1 ) i j   xor   i k a [ j ] b [ k ] fwt[a][i]\cdot fwt[b][i]=\left( \sum (-1)^{i\otimes j} a[j]\right) \cdot \left( \sum (-1)^{i\otimes k} b[k]\right) =\sum (-1)^{i\otimes j~~\text{xor}~ ~i \otimes k}a[j]\cdot b[k]
( 1 ) i j   xor   i k a [ j ] b [ k ] = ( 1 ) i ( j  xor  k ) a [ j ] b [ k ] = ( 1 ) i j c [ j ] = f w t [ c ] [ i ] \sum (-1)^{i\otimes j~~\text{xor}~ ~i \otimes k}a[j]\cdot b[k]=\sum (-1)^{i\otimes(j ~\text{xor} ~k)} a[j] \cdot b[k]=\sum (-1)^{i\otimes j} c[j]=fwt[c][i]





正变换:
f w t [ a ] = merge ( f w t [ a 0 ] + f w t [ a 1 ] , f w t [ a 0 ] f w t [ a 1 ] ) fwt[a]=\text{merge} (fwt[a_0]+fwt[a_1],fwt[a_0]-fwt[a_1]) .

证明:在求解小规模数据时 a 1 a_1 时不知道自己最高位位1的,
此时, i a 0 , j a 1 , c n t ( i & j ) = c n t ( ( i + n / 2 ) & j ) = f w t [ a 1 ] [ i ] i\in a_0,j\in a_1,cnt(i\&j)=cnt((i+n/2)\&j)=fwt[a_1][i] .
i a 1 , j a 1 , c n t ( ( i + n / 2 ) & ( j + n / 2 ) ) = c n t ( i & j ) + 1 i\in a_1,j\in a_1,cnt((i+n/2)\& (j+n/2))=cnt(i\&j)+1 ,所以右边合并的时候 f w t [ a 1 ] fwt[a_1] 的符号改变.

逆变换:
I F W T ( a ) = merge ( I F W T ( a 0 + a 1 2 , I F W T ( a 0 a 1 2 ) ) IFWT(a)=\text{merge} (IFWT(\dfrac{a_0+a_1}2,IFWT(\dfrac {a_0-a_1} 2)) .

void fwt_xor(int *f,ll x) {
    if(x==-1) x=(mod+1)/2;
    for(int k=1;k<n;k*=2)
        for(int i=0;i<n;i+=2*k)
            for(int j=0;j<k;j++)  {
                int u=f[i+j],v=f[i+j+k];
                add(f[i+j],v); del(f[i+j+k]=u,v);
                f[i+j] = f[i+j]*x%mod;
                f[i+j+k] = f[i+j+k]*x%mod;
            }
}

模板题

板子:

int n,a[N],b[N],A[N],B[N];
void add(int &x,int y) {x+=y; if(x>=mod)  x-= mod;}
void upd(int &x) {x+=x>>31&mod;}
void del(int &x,int y) {upd(x-=y);}

void fwt_or(int *f,ll x) {
    if(x==-1) x+=mod;
    for(int k=1;k<n;k*=2)
        for(int i=0;i<n;i+=2*k)
            for(int j=0;j<k;j++)
                add(f[i+j+k],f[i+j]*x%mod); 
}

void fwt_and(int *f,ll x) {
    if(x==-1) x+=mod;
    for(int k=1;k<n;k*=2)
        for(int i=0;i<n;i+=2*k)
            for(int j=0;j<k;j++) 
                add(f[i+j],f[i+j+k]*x%mod);
}

void fwt_xor(int *f,ll x) {
    if(x==-1) x=(mod+1)/2;
    for(int k=1;k<n;k*=2)
        for(int i=0;i<n;i+=2*k)
            for(int j=0;j<k;j++)  {
                int u=f[i+j],v=f[i+j+k];
                add(f[i+j],v); del(f[i+j+k]=u,v);
                f[i+j] = f[i+j]*x%mod;
                f[i+j+k] = f[i+j+k]*x%mod;
            }
}

void solve(void (*fwt)(int*f,ll x)) {
    for(int i=0;i<n;i++) a[i]=A[i],b[i]=B[i];
    fwt(a,1); fwt(b,1);
    for(int i=0;i<n;i++) a[i]=(ll)a[i]*b[i]%mod;
    fwt(a,-1);
    for(int i=0;i<n;i++) pr1(a[i]);
    puts("");
}

int main() {
    qr(n); n=1<<n;
    for(int i=0;i<n;i++) qr(A[i]);
    for(int i=0;i<n;i++) qr(B[i]);
    solve(fwt_or);
    solve(fwt_and);
    solve(fwt_xor);
    return 0;
}

例题

bzoj #4589. Hard Nim

有两个神在玩nim游戏,有 n n 堆石子,每堆石子的大小为 m \le m 的质数,求先手必败的方案数.
m 50000 , n 1 0 9 m\le 50000,n\le 10^9 .

定义一个多项式 f , f [ p ] = [ p m p p r i m e ] f,f[p]=[p\le m\cap p\in prime]
定义 f × g f\times g 表示 f , g f,g 对应位置相乘,区别于 f g f*g (表示卷积).
则我们要求的就是 I F W T ( f w t [ f ] n ) [ 0 ] IFWT(fwt[f]^n)[0] .
因为每乘一次就相当于作一次异或卷积,即多一堆石子,所以正确.

复杂度为 O ( m ( log m + log n ) ) O(m(\log m+\log n)) .


void add(int &x,int y) {x+=y; if(x>=mod) x-=mod;}
void upd(int &x) {x+=x>>31&mod;}
void del(int &x,int y) {upd(x -= y);}


int prime[N],tot; bool v[N];
void get(int x) {
    for(int i=2;i<=x;i++) {
        if(!v[i]) prime[++tot]=i;
        for(int j=1,k;(k=i*prime[j])<=x;j++) {
            v[k]=1;
            if(i%prime[j]==0) break;
        }
    }
}

int t;
void fwt(int *f,ll x) {
    if(x==-1) x=(mod+1)/2;
    for(int k=1;k<t;k*=2)
        for(int i=0;i<t;i+=2*k)
            for(int j=0;j<k;j++) {
                int u=f[i+j],v=f[i+j+k];
                add(f[i+j],v); del(f[i+j+k]=u,v);
                f[i+j]=f[i+j]*x%mod;
                f[i+j+k]=f[i+j+k]*x%mod;
            }
}

int n,m,f[N];
ll power(ll a,ll b=n) {
    ll c=1;
    while(b) {
        if(b&1) c=c*a%mod;
        b /= 2; a=a*a%mod;
    }
    return c;
}

int main() {

    get(N-1);
    while(~scanf("%d%d",&n,&m)) {
        for(t=1;t<=m;t*=2);
        memset(f,0,sizeof f);
        for(int i=1;i<=tot&&prime[i]<=m;i++) f[prime[i]]=1;;
        fwt(f,1);
        for(int i=0;i<t;i++) f[i]=power(f[i]);
        fwt(f,-1); pr2(f[0]);
    }
    return 0;
}

P3175 [HAOI2015]按位或

定义 m i n ( T ) min(T) 为取到 T T 集合中任意一位的最小时间.
m i n ( T ) = 1 S T p S = 1 1 f w t [ T ] min(T)=\dfrac 1{\sum_{S\cap T\ne \varnothing} p_S}=\dfrac 1 {1-fwt[\overline T]}

#include<bits/stdc++.h>
using namespace std;
const int N=(1<<20)|10;
const double eps=1e-9;

int n,g[N];
double f[N],ans;

void fwt(double *f) {
    for(int k=1;k<n;k*=2)
        for(int i=0;i<n;i+=2*k)
            for(int j=0;j<k;j++) 
                f[i+j+k] += f[i+j];
}

int main() {
    scanf("%d",&n); n=1<<n;
    for(int i=0;i<n;i++) scanf("%lf",&f[i]);
    fwt(f); g[0]=-1;
    for(int i=1;i<n;i++) {
        g[i]=-g[i&(i-1)];
        if(fabs(f[i^(n-1)]-1)<eps) 
            {puts("INF"); return 0;}
        ans += g[i]/(1-f[i^(n-1)]);
    }
    printf("%.10lf\n",ans); 
    return 0;
}

猜你喜欢

转载自blog.csdn.net/qq_42886072/article/details/108072430
FWT