FWT学习笔记

快速沃尔什变换学习笔记


\(or\)

\(f[i][x]\)表示第\(i+1\)位到第\(n\)位相同,第\(1\)位到第\(i\)位是\(x\)的子集的\(a[y]\)的和

于是FMT后的数组就是 \(f[n][x]\)

考虑如何计算\(f[i][x]\)

如果\(x\)的第\(i\)位是\(0\),那么\(f[i][x]=f[i-1][x]\)

如果是\(1\),那么\(f[i][x]=f[i-1][x]+f[i-1][x-2^{i-1}]\)

用滚动数组优化可以做到空间复杂度\(O(n)\)

对于第\(i\)层来说,相当于把整个序列分成了\(2^{n-i}\)

每一段中的第\(i+1\)位到第\(n\)位相同,且每段左半段第\(i\)位是\(0\),右半段第\(i\)位是\(1\),相当于左半段对右半段对应的位置产生了贡献

代码就很容易写出来了(^_^)

FMT的逆变换

与正变换类似,\(f[i][x]\)表示第\(i+1\)位到第\(n\)位是\(x\)的子集,且第\(1\)位到第\(i\)位相等的\(a[y]\)的和

如果\(x\)的第\(i\)位是\(0\),那么\(f[i][x]=f[i-1][x]\)

如果是\(1\),那么\(f[i][x]=f[i-1][x]-f[i-1][x-2^{i-1}]\)

是不是很简单(^_^)


\(and\)

\(or\)的本质相同


\(xor\)

这个就比较难了


代码:

#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;

const int mm=998244353;
const int maxn=1000000;
const int inv2=499122177;

int n;
long long A[maxn];
long long B[maxn];
long long C[maxn];

void FMTor(long long *arr,int n,int f){
    for(int k=1;k<n;k<<=1){
        int p=k+k;
        for(int i=0;i<n;i+=p){
            for(int j=0;j<k;++j){
                if(f==1){
                    arr[i+j+k]=(arr[i+j+k]+arr[i+j])%mm;
                }else{
                    arr[i+j+k]=(arr[i+j+k]-arr[i+j]+mm)%mm;
                }
            }
        }
    }
}
void FMTand(long long *arr,int n,int f){
    for(int k=1;k<n;k<<=1){
        int p=k+k;
        for(int i=0;i<n;i+=p){
            for(int j=0;j<k;++j){
                if(f==1){
                    arr[i+j]=(arr[i+j]+arr[i+j+k])%mm;
                }else{
                    arr[i+j]=(arr[i+j]-arr[i+j+k]+mm)%mm;
                }
            }
        }
    }
}

void FWTxor(long long *arr,int n,int f){
    for(int k=1;k<n;k<<=1){
        int p=k+k;
        for(int i=0;i<n;i+=p){
            for(int j=0;j<k;++j){
                long long x=arr[i+j],y=arr[i+j+k];
                if(f==1){
                    arr[i+j]=(x+y)%mm;
                    arr[i+j+k]=(x-y+mm)%mm;
                }else{
                    arr[i+j]=(x+y)*inv2%mm;
                    arr[i+j+k]=(x-y+mm)*inv2%mm;
                }
            }
        }
    }
}

int main(){
    scanf("%d",&n);
    for(int i=0;i<(1<<n);++i)scanf("%lld",&A[i]);
    for(int i=0;i<(1<<n);++i)scanf("%lld",&B[i]);
    
    FMTor(A,1<<n,1);
    FMTor(B,1<<n,1);
    for(int i=0;i<(1<<n);++i)C[i]=A[i]*B[i]%mm;
    FMTor(A,1<<n,-1);
    FMTor(B,1<<n,-1);
    FMTor(C,1<<n,-1);
    for(int i=0;i<(1<<n);++i)printf("%lld ",C[i]);
    printf("\n");
    
    FMTand(A,1<<n,1);
    FMTand(B,1<<n,1);
    for(int i=0;i<(1<<n);++i)C[i]=A[i]*B[i]%mm;
    FMTand(A,1<<n,-1);
    FMTand(B,1<<n,-1);
    FMTand(C,1<<n,-1);
    for(int i=0;i<(1<<n);++i)printf("%lld ",C[i]);
    printf("\n");
    
    FWTxor(A,1<<n,1);
    FWTxor(B,1<<n,1);
    for(int i=0;i<(1<<n);++i)C[i]=A[i]*B[i]%mm;
    FWTxor(A,1<<n,-1);
    FWTxor(B,1<<n,-1);
    FWTxor(C,1<<n,-1);
    for(int i=0;i<(1<<n);++i)printf("%lld ",C[i]);
    printf("\n");
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/zzyer/p/9285283.html
FWT